Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / tools / nnpackage_tool / gen_golden / gen_golden.py
1 #!/usr/bin/env python3
2
3 # Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 import warnings
18 with warnings.catch_warnings():
19     warnings.filterwarnings("ignore", category=FutureWarning)
20     import tensorflow as tf
21 import os
22 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
23 import sys
24 import argparse
25 import numpy as np
26
27
28 # cmd arguments parsing
29 def usage():
30     script = os.path.basename(os.path.basename(__file__))
31     print("Usage: {} path_to_pb".format(script))
32     sys.exit(-1)
33
34
35 if __name__ == '__main__':
36     parser = argparse.ArgumentParser()
37     parser.add_argument(
38         'modelfile',
39         type=str,
40         help='path to modelfile in either graph_def (.pb) or tflite (.tflite)')
41     parser.add_argument(
42         '-o', '--output', action='store', dest="out_dir", help="output directory")
43     args = parser.parse_args()
44
45     if len(sys.argv) == 1:
46         parser.parse_args()
47         sys.exit(1)
48
49     filename = args.modelfile
50
51     if args.out_dir:
52         out_dir = args.out_dir + '/'
53     else:
54         out_dir = "./"
55
56     _, extension = os.path.splitext(filename)
57
58     input_names = []
59     output_names = []
60     input_dtypes = []
61     output_dtypes = []
62     input_values = []
63     output_values = []
64
65     if extension == ".pb":
66         # import graph_def (pb)
67         graph = tf.compat.v1.get_default_graph()
68         graph_def = tf.compat.v1.GraphDef()
69
70         with tf.io.gfile.GFile(filename, 'rb') as f:
71             graph_def.ParseFromString(f.read())
72             tf.import_graph_def(graph_def, name='')
73
74         # identify input namess and output names
75         ops = graph.get_operations()
76         input_names = [op.outputs[0].name for op in ops if op.type == "Placeholder"]
77         output_names = [tensor.name for op in ops for tensor in op.outputs]
78         for op in ops:
79             for t in op.inputs:
80                 if t.name in output_names:
81                     output_names.remove(t.name)
82
83         # identify input dtypes and output dtypes
84         input_dtypes = [graph.get_tensor_by_name(name).dtype for name in input_names]
85         output_dtypes = [graph.get_tensor_by_name(name).dtype for name in output_names]
86
87         # gen random input values
88         for idx in range(len(input_names)):
89             this_shape = graph.get_tensor_by_name(input_names[idx]).shape
90             this_dtype = input_dtypes[idx]
91             if this_dtype == tf.uint8:
92                 input_values.append(
93                     np.random.randint(0, 255, this_shape).astype(np.uint8))
94             elif this_dtype == tf.float32:
95                 input_values.append(
96                     np.random.random_sample(this_shape).astype(np.float32))
97             elif this_dtype == tf.bool:
98                 # generate random integer from [0, 2)
99                 input_values.append(
100                     np.random.randint(2, size=this_shape).astype(np.bool_))
101             elif this_dtype == tf.int32:
102                 input_values.append(np.random.randint(0, 99, this_shape).astype(np.int32))
103             elif this_dtype == tf.int64:
104                 input_values.append(np.random.randint(0, 99, this_shape).astype(np.int64))
105
106         # get output values by running
107         config = tf.compat.v1.ConfigProto()
108         config.gpu_options.allow_growth = True
109         with tf.compat.v1.Session(config=config) as sess:
110             output_values = sess.run(
111                 output_names, feed_dict=dict(zip(input_names, input_values)))
112
113     elif extension == ".tflite":
114         # load TFLite model and allocate tensors
115         interpreter = tf.lite.Interpreter(filename)
116         interpreter.allocate_tensors()
117
118         # get list of tensors details for input/output
119         input_details = interpreter.get_input_details()
120         output_details = interpreter.get_output_details()
121
122         # identify input namess and output names
123         input_names = [d['name'] for d in input_details]
124         output_names = [d['name'] for d in output_details]
125
126         # identify input dtypes and output dtypes
127         input_dtypes = [d['dtype'] for d in input_details]
128         output_dtypes = [d['dtype'] for d in output_details]
129
130         # gen random input values and set tensor
131         for idx in range(len(input_details)):
132             this_shape = input_details[idx]['shape']
133             this_dtype = input_details[idx]['dtype']
134             if this_dtype == np.uint8:
135                 input_values.append(
136                     np.random.randint(0, 255, this_shape).astype(np.uint8))
137             elif this_dtype == np.float32:
138                 input_values.append(
139                     np.random.random_sample(this_shape).astype(np.float32))
140             elif this_dtype == np.bool_:
141                 # generate random integer from [0, 2)
142                 input_values.append(
143                     np.random.randint(2, size=this_shape).astype(np.bool_))
144             elif this_dtype == np.int32:
145                 input_values.append(np.random.randint(0, 99, this_shape).astype(np.int32))
146             elif this_dtype == np.int64:
147                 input_values.append(np.random.randint(0, 99, this_shape).astype(np.int64))
148             interpreter.set_tensor(input_details[idx]['index'], input_values[idx])
149
150         # get output values by running
151         interpreter.invoke()
152         for idx in range(len(output_details)):
153             output_values.append(interpreter.get_tensor(output_details[idx]['index']))
154
155     else:
156         print("Only .pb and .tflite models are supported.")
157         sys.exit(-1)
158
159     # dump input and output in h5
160     import h5py
161     supported_dtypes = ("float32", "uint8", "bool", "int32", "int64")
162     h5dtypes = {
163         "float32": ">f4",
164         "uint8": "u1",
165         "bool": "u1",
166         "int32": "int32",
167         "int64": "int64"
168     }
169     with h5py.File(out_dir + "input.h5", 'w') as hf:
170         name_grp = hf.create_group("name")
171         val_grp = hf.create_group("value")
172         for idx, t in enumerate(input_names):
173             dtype = tf.compat.v1.as_dtype(input_dtypes[idx])
174             if not dtype.name in supported_dtypes:
175                 print("ERR: Supported input types are {}".format(supported_dtypes))
176                 sys.exit(-1)
177             val_grp.create_dataset(
178                 str(idx), data=input_values[idx], dtype=h5dtypes[dtype.name])
179             name_grp.attrs[str(idx)] = input_names[idx]
180
181     with h5py.File(out_dir + "expected.h5", 'w') as hf:
182         name_grp = hf.create_group("name")
183         val_grp = hf.create_group("value")
184         for idx, t in enumerate(output_names):
185             dtype = tf.compat.v1.as_dtype(output_dtypes[idx])
186             if not dtype.name in supported_dtypes:
187                 print("ERR: Supported output types are {}".format(supported_dtypes))
188                 sys.exit(-1)
189             val_grp.create_dataset(
190                 str(idx), data=output_values[idx], dtype=h5dtypes[dtype.name])
191             name_grp.attrs[str(idx)] = output_names[idx]