3 # Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
# Import TensorFlow while suppressing its noisy FutureWarning output
# (TF 1.x-era imports emit many deprecation warnings).
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    import tensorflow as tf

# Silence the TF C++ backend logs: "2" keeps errors only (drops INFO/WARNING).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# cmd arguments parsing
# NOTE(review): these two statements read like the body of a print-usage
# helper; the enclosing `def` line is not visible in this chunk -- confirm
# ownership and indentation against the full file.
# NOTE(review): the nested os.path.basename(os.path.basename(...)) is
# redundant (basename of a basename is a no-op) but harmless.
script = os.path.basename(os.path.basename(__file__))
print("Usage: {} path_to_pb".format(script))
if __name__ == '__main__':
    # Parse command-line options: -m/--model (the model file) and
    # -o/--output (directory where input.h5 / expected.h5 are written).
    parser = argparse.ArgumentParser()
    # NOTE(review): the opening `parser.add_argument(` calls are not visible
    # in this chunk -- only their continuation lines appear below. The first
    # option stores into args.modelfile, the second into args.out_dir.
        help='path to modelfile in either graph_def (.pb) or tflite (.tflite)')
        '-o', '--output', action='store', dest="out_dir", help="output directory")
    args = parser.parse_args()

    # No arguments at all: presumably print usage and exit.
    # NOTE(review): the body of this `if` is not visible in this chunk.
    if len(sys.argv) == 1:

    filename = args.modelfile
    # Normalize the output directory with a trailing '/' so that plain string
    # concatenation (out_dir + "input.h5") works below.
    # NOTE(review): the surrounding if/else (default when -o is omitted) is
    # not visible in this chunk.
    out_dir = args.out_dir + '/'

    # Dispatch on the model file extension (.pb vs .tflite).
    _, extension = os.path.splitext(filename)

    if extension == ".pb":
        # import graph_def (pb)
        graph = tf.compat.v1.get_default_graph()
        graph_def = tf.compat.v1.GraphDef()
        with tf.io.gfile.GFile(filename, 'rb') as f:
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

        # identify input names and output names
        ops = graph.get_operations()
        # Inputs: first output tensor of every Placeholder op.
        input_names = [op.outputs[0].name for op in ops if op.type == "Placeholder"]
        # Outputs: start from every tensor in the graph, then prune tensors
        # that are consumed as inputs by some op, leaving only graph outputs.
        # NOTE(review): the loop headers driving `t` (presumably iterating op
        # inputs) are not visible in this chunk.
        output_names = [tensor.name for op in ops for tensor in op.outputs]
            if t.name in output_names:
                output_names.remove(t.name)

        # identify input dtypes and output dtypes
        input_dtypes = [graph.get_tensor_by_name(name).dtype for name in input_names]
        output_dtypes = [graph.get_tensor_by_name(name).dtype for name in output_names]

        # gen random input values
        # NOTE(review): input_values is appended to below, but its
        # initialization (input_values = []) is not visible in this chunk.
        for idx in range(len(input_names)):
            this_shape = graph.get_tensor_by_name(input_names[idx]).shape
            this_dtype = input_dtypes[idx]
            if this_dtype == tf.uint8:
                # NOTE(review): randint's high bound is exclusive, so this
                # draws 0..254 -- confirm whether 0..255 was intended.
                # (The `input_values.append(` opener is elided here.)
                np.random.randint(0, 255, this_shape).astype(np.uint8))
            elif this_dtype == tf.float32:
                # uniform floats in [0.0, 1.0)
                np.random.random_sample(this_shape).astype(np.float32))
            elif this_dtype == tf.bool:
                # generate random integer from [0, 2)
                np.random.randint(2, size=this_shape).astype(np.bool_))
            elif this_dtype == tf.int32:
                input_values.append(np.random.randint(0, 99, this_shape).astype(np.int32))
            elif this_dtype == tf.int64:
                input_values.append(np.random.randint(0, 99, this_shape).astype(np.int64))

        # get output values by running
        config = tf.compat.v1.ConfigProto()
        # Let TF grab GPU memory incrementally instead of reserving it all.
        config.gpu_options.allow_growth = True
        with tf.compat.v1.Session(config=config) as sess:
            # Single run: feed all random inputs, fetch every output tensor.
            output_values = sess.run(
                output_names, feed_dict=dict(zip(input_names, input_values)))

    elif extension == ".tflite":
        # load TFLite model and allocate tensors
        interpreter = tf.lite.Interpreter(filename)
        interpreter.allocate_tensors()

        # get list of tensors details for input/output
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        # identify input names and output names
        input_names = [d['name'] for d in input_details]
        output_names = [d['name'] for d in output_details]

        # identify input dtypes and output dtypes
        input_dtypes = [d['dtype'] for d in input_details]
        output_dtypes = [d['dtype'] for d in output_details]

        # gen random input values and set tensor
        # (same dtype-dispatch as the .pb branch, but keyed on numpy dtypes
        # since TFLite tensor details report numpy types)
        for idx in range(len(input_details)):
            this_shape = input_details[idx]['shape']
            this_dtype = input_details[idx]['dtype']
            if this_dtype == np.uint8:
                # NOTE(review): `input_values.append(` opener elided here,
                # as in the .pb branch; high bound 255 is exclusive.
                np.random.randint(0, 255, this_shape).astype(np.uint8))
            elif this_dtype == np.float32:
                np.random.random_sample(this_shape).astype(np.float32))
            elif this_dtype == np.bool_:
                # generate random integer from [0, 2)
                np.random.randint(2, size=this_shape).astype(np.bool_))
            elif this_dtype == np.int32:
                input_values.append(np.random.randint(0, 99, this_shape).astype(np.int32))
            elif this_dtype == np.int64:
                input_values.append(np.random.randint(0, 99, this_shape).astype(np.int64))
            # feed the freshly generated value into the interpreter
            interpreter.set_tensor(input_details[idx]['index'], input_values[idx])

        # get output values by running
        # NOTE(review): interpreter.invoke() and the output_values = []
        # initialization are not visible in this chunk; presumably both
        # precede this loop.
        for idx in range(len(output_details)):
            output_values.append(interpreter.get_tensor(output_details[idx]['index']))

        # Unsupported extension: report and (presumably) exit.
        # NOTE(review): the `else:` introducing this branch and any exit call
        # are not visible in this chunk.
        print("Only .pb and .tflite models are supported.")

    # dump input and output in h5
    # NOTE(review): the h5dtypes mapping used below (dtype name -> h5py dtype
    # string) is defined outside the visible chunk.
    supported_dtypes = ("float32", "uint8", "bool", "int32", "int64")

    # input.h5 layout: group "name" holds index->tensor-name attrs,
    # group "value" holds index->ndarray datasets.
    with h5py.File(out_dir + "input.h5", 'w') as hf:
        name_grp = hf.create_group("name")
        val_grp = hf.create_group("value")
        for idx, t in enumerate(input_names):
            dtype = tf.compat.v1.as_dtype(input_dtypes[idx])
            if not dtype.name in supported_dtypes:
                # NOTE(review): only prints; any exit after the error message
                # is not visible in this chunk.
                print("ERR: Supported input types are {}".format(supported_dtypes))
            val_grp.create_dataset(
                str(idx), data=input_values[idx], dtype=h5dtypes[dtype.name])
            name_grp.attrs[str(idx)] = input_names[idx]

    # expected.h5 mirrors input.h5's layout for the model's outputs.
    with h5py.File(out_dir + "expected.h5", 'w') as hf:
        name_grp = hf.create_group("name")
        val_grp = hf.create_group("value")
        for idx, t in enumerate(output_names):
            dtype = tf.compat.v1.as_dtype(output_dtypes[idx])
            if not dtype.name in supported_dtypes:
                print("ERR: Supported output types are {}".format(supported_dtypes))
            val_grp.create_dataset(
                str(idx), data=output_values[idx], dtype=h5dtypes[dtype.name])
            name_grp.attrs[str(idx)] = output_names[idx]