Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-import-onnx
#!/usr/bin/env bash
# Bash/Python polyglot bootstrap. Bash executes the next lines: the leading ''''
# is just two empty quoted strings and everything after '#' is a comment. Python
# sees each line as a no-op triple-quoted string literal and skips it.
# Bash re-executes this file with the venv interpreter, or fails with status 255.
''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH="${SCRIPT_PATH}/venv/bin/python"                                     # '''
''''test -f "${PY_PATH}" && exec "${PY_PATH}" "$0" "$@"                                 # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." >&2 # '''
''''exit 255                                                                            # '''
7
8 # Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 #    http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21
22 import argparse
23 import os
24 import sys
25 import tempfile
26 import onnx
27 import onnx_tf
28
29 # ONNX legalizer is an optional feature
30 # It enables conversion of some operations, but in experimental phase for now
31 try:
32     import onnx_legalizer
33     _onnx_legalizer_enabled = True
34 except ImportError:
35     _onnx_legalizer_enabled = False
36
37 import onelib.make_cmd as _make_cmd
38 import onelib.utils as oneutils
39
# TODO Find better way to suppress traceback on error
# A value of 0 makes the interpreter print only the exception type and message
# for uncaught errors, without any traceback frames.
sys.tracebacklimit = 0
42
43
# Renames graph inputs/outputs of an ONNX model to avoid issues while importing it
class TidyIONames:
    """Rename the graph inputs/outputs of an ONNX model.

    Use in two steps: ``order()`` or ``sanitize()`` records pairs of
    (original name, new name), then ``update()`` applies them to the graph
    inputs, the graph outputs and the inputs/outputs of every node in
    ``onnx_model.graph.node``. Subgraphs held in node attributes are not
    visited.

    The recorded pairs are exposed as parallel lists: ``input_nodes[k]`` is
    renamed to ``remap_inputs[k]`` and ``output_nodes[k]`` to
    ``remap_outputs[k]``. Graph inputs that are also initializers are never
    recorded, so they keep their names.
    """

    def __init__(self, onnx_model):
        self.input_nodes = []
        self.output_nodes = []
        self.remap_inputs = []
        self.remap_outputs = []
        self.onnx_model = onnx_model
        # some models also list initializers as graph inputs; those are skipped
        self.initializers = [init.name for init in onnx_model.graph.initializer]

    @staticmethod
    def _needs_sanitize(name):
        """Return True if name contains '.' or ':' or starts with a digit."""
        return '.' in name or ':' in name or name[:1].isdigit()

    @staticmethod
    def _sanitized(name):
        """Return name with '.'/':' replaced by '_', prefixed 'a_' if it starts with a digit."""
        name_alt = name.replace('.', '_').replace(':', '_')
        if name_alt[:1].isdigit():
            name_alt = 'a_' + name_alt
        return name_alt

    @staticmethod
    def _rename_map(old_names, new_names):
        """Build {old: new} from parallel lists; the first pair wins for a repeated old name."""
        mapping = {}
        for old, new in zip(old_names, new_names):
            mapping.setdefault(old, new)
        return mapping

    def order(self):
        """Record 'i_NNNN_<name>' / 'o_NNNN_<name>' renames for all graph inputs/outputs.

        NNNN is the 1-based position in graph.input / graph.output, zero-padded
        to 4 digits. Presumably this lets a name-sorted I/O list keep the
        original order (which holds for up to 9999 entries). Positions of
        skipped initializer inputs are still counted.
        """
        initializers = set(self.initializers)
        graph = self.onnx_model.graph
        for idx, graph_input in enumerate(graph.input):
            name = graph_input.name
            if name not in initializers:
                self.input_nodes.append(name)
                self.remap_inputs.append('i_' + format(idx + 1, '04d') + '_' + name)
        for idx, graph_output in enumerate(graph.output):
            name = graph_output.name
            self.output_nodes.append(name)
            self.remap_outputs.append('o_' + format(idx + 1, '04d') + '_' + name)

    # exclude special characters in names
    def sanitize(self):
        """Record renames for graph I/O names containing '.' or ':' or starting with a digit."""
        initializers = set(self.initializers)
        graph = self.onnx_model.graph
        for graph_input in graph.input:
            name = graph_input.name
            if name not in initializers and self._needs_sanitize(name):
                self.input_nodes.append(name)
                self.remap_inputs.append(self._sanitized(name))
        for graph_output in graph.output:
            name = graph_output.name
            if self._needs_sanitize(name):
                self.output_nodes.append(name)
                self.remap_outputs.append(self._sanitized(name))

    def update(self):
        """Apply the recorded renames to the model in place.

        Every name is looked up by its ORIGINAL value only, so a tensor is
        renamed at most once, even when its new name equals another recorded
        original name. Node inputs check the input renames first and then the
        output renames, because a graph output may also feed other nodes. Node
        outputs check the output renames first.
        """
        graph = self.onnx_model.graph
        input_map = self._rename_map(self.input_nodes, self.remap_inputs)
        output_map = self._rename_map(self.output_nodes, self.remap_outputs)

        # change names for graph input
        for graph_input in graph.input:
            if graph_input.name in input_map:
                graph_input.name = input_map[graph_input.name]

        # change names of all nodes in the graph
        for node in graph.node:
            for j in range(len(node.input)):
                name = node.input[j]
                if name in input_map:
                    node.input[j] = input_map[name]
                elif name in output_map:
                    node.input[j] = output_map[name]
            for j in range(len(node.output)):
                name = node.output[j]
                if name in output_map:
                    node.output[j] = output_map[name]
                elif name in input_map:
                    node.output[j] = input_map[name]

        # change names for graph output
        for graph_output in graph.output:
            if graph_output.name in output_map:
                graph_output.name = output_map[graph_output.name]
125
126
def get_driver_cfg_section():
    """Return the name of the configuration-file section used by this driver."""
    section_name = "one-import-onnx"
    return section_name
129
130
def _get_parser():
    """Build the command-line parser for one-import-onnx.

    Returns:
        argparse.ArgumentParser holding the common one-cmds options (added by
        oneutils.add_default_arg), a 'converter arguments' group for the
        tf2tfliteV2 conversion and the ONNX-specific boolean switches.
    """
    parser = argparse.ArgumentParser(
        description='command line tool to convert ONNX to circle')
    oneutils.add_default_arg(parser)

    ## tf2tfliteV2 arguments
    converter_group = parser.add_argument_group('converter arguments')

    # string-valued options: (short flag, long flag, help)
    string_options = (
        ('-i', '--input_path', 'full filepath of the input file'),
        ('-o', '--output_path', 'full filepath of the output file'),
        ('-I', '--input_arrays', 'names of the input arrays, comma-separated'),
        ('-O', '--output_arrays', 'names of the output arrays, comma-separated'),
    )
    for short_flag, long_flag, help_text in string_options:
        converter_group.add_argument(short_flag, long_flag, type=str, help=help_text)

    # fixed options for tf2tfliteV2, given defaults here
    converter_group.add_argument('--model_format', default='saved_model')
    converter_group.add_argument('--converter_version', default='v2')

    # store_true switches on the main parser: (flag, help)
    switches = (
        ('--unroll_rnn', 'Unroll RNN operators'),
        ('--unroll_lstm', 'Unroll LSTM operators'),
        ('--keep_io_order',
         'Ensure generated circle model preserves the I/O order of the original onnx model.'),
        # save intermediate file(s)
        ('--save_intermediate', 'Save intermediate files to output folder'),
        # experimental options
        ('--experimental_disable_batchmatmul_unfold',
         'Experimental disable BatchMatMul unfold'),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action='store_true', help=help_text)

    return parser
185
186
def _verify_arg(parser, args):
    """Ensure the mandatory paths are set; otherwise report them via parser.error.

    These are presumably checked here rather than with argparse's required=True
    because main() may fill them from the configuration file first.
    """
    required = (
        ('input_path', '-i/--input_path'),
        ('output_path', '-o/--output_path'),
    )
    missing = [flag for attr, flag in required if not oneutils.is_valid_attr(args, attr)]
    if missing:
        parser.error('the following arguments are required: ' + ' '.join(missing))
197
198
def _parse_arg(parser):
    """Parse the command line with parser.

    When --version is given, oneutils.print_version_and_exit is called first.

    Returns:
        The argparse.Namespace produced by parser.parse_args().
    """
    parsed = parser.parse_args()
    if not parsed.version:
        return parsed
    # print version
    oneutils.print_version_and_exit(__file__)
    return parsed
206
207
def _apply_verbosity(verbosity):
    """Set TensorFlow's C++ log level from the verbosity flag.

    A truthy verbosity sets TF_CPP_MIN_LOG_LEVEL to '0'; otherwise it is set
    to '2'. The levels mean:
      0 : INFO + WARNING + ERROR + FATAL
      1 : WARNING + ERROR + FATAL
      2 : ERROR + FATAL
      3 : FATAL
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if verbosity else '2'
219
220
def _sanitize_io_names(onnx_model):
    """Rename onnx_model's graph inputs/outputs in place to remove special characters.

    Names containing '.' or ':' or starting with a digit are rewritten (see
    TidyIONames.sanitize). This is a workaround: TF 2.12.1 tries to sanitize
    such characters itself and then fails elsewhere with
    'IndexError: tuple index out of range'.
    """
    io_tidier = TidyIONames(onnx_model)
    io_tidier.sanitize()
    io_tidier.update()
228
229
def _remap_io_names(onnx_model):
    """Prefix onnx_model's graph I/O names with their 1-based position, in place.

    Example: inputs 'a', 'c', 'b' become 'i_0001_a', 'i_0002_c', 'i_0003_b'.
    Outputs get the 'o_' prefix instead. This preserves the I/O order after
    import. Positions count every graph.input entry, including initializer
    inputs, which themselves keep their names.
    """
    io_orderer = TidyIONames(onnx_model)
    io_orderer.order()
    io_orderer.update()
239
240
def _check_ext():
    """Locate the optional 'one-import-onnx-ext' converter.

    Returns:
        Its path if a regular file with that name sits in the directory of
        this script (after resolving symlinks), otherwise None.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    candidate = os.path.join(script_dir, 'one-import-onnx-ext')
    return candidate if os.path.isfile(candidate) else None
247
248
def _make_ext_cmd(ext_path, onnx_path, args):
    """Return the command line that runs the one-import-onnx-ext extension.

    Each boolean option enabled in args is forwarded as '--<option>' in a fixed
    order, followed by the ONNX model to convert (onnx_path) and
    args.output_path.
    """
    ext_cmd = [ext_path]
    for option in ('unroll_rnn', 'unroll_lstm',
                   'experimental_disable_batchmatmul_unfold', 'save_intermediate',
                   'keep_io_order'):
        if oneutils.is_valid_attr(args, option):
            ext_cmd.append('--' + option)
    ext_cmd.append(onnx_path)
    ext_cmd.append(args.output_path)
    return ext_cmd


def _convert(args):
    """Convert the ONNX model at args.input_path into a circle model at args.output_path.

    Steps:
      1. load the model, sanitize its I/O names, legalize it when the optional
         onnx_legalizer is available, and reorder I/O names on --keep_io_order;
      2. if 'one-import-onnx-ext' exists next to this script, hand the model to
         it and stop;
      3. otherwise ONNX -> TF saved model (onnx_tf) -> tflite (tf2tfliteV2.py)
         -> circle (tflite2circle).

    The commands are written to '<output_path>.log', which is also passed to
    oneutils.run as its logfile. Intermediate files go to a temporary
    directory, or next to the log file when --save_intermediate is set.
    """
    _apply_verbosity(args.verbose)

    # helper tools (tf2tfliteV2.py, tflite2circle) live next to this script
    dir_path = os.path.dirname(os.path.realpath(__file__))
    logfile_path = os.path.realpath(args.output_path) + '.log'
    ext_path = _check_ext()
    input_stem = os.path.splitext(os.path.basename(args.input_path))[0]
    output_stem = os.path.splitext(os.path.basename(args.output_path))[0]

    with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
        # save intermediate: write them next to the log instead of the temp dir
        if oneutils.is_valid_attr(args, 'save_intermediate'):
            tmpdir = os.path.dirname(logfile_path)

        onnx_model = onnx.load(args.input_path)
        _sanitize_io_names(onnx_model)
        if _onnx_legalizer_enabled:
            # Use an instance: assigning attributes on the LegalizeOptions class
            # itself would overwrite its class-level defaults.
            options = onnx_legalizer.LegalizeOptions()
            options.unroll_rnn = oneutils.is_valid_attr(args, 'unroll_rnn')
            options.unroll_lstm = oneutils.is_valid_attr(args, 'unroll_lstm')
            onnx_legalizer.legalize(onnx_model, options)
        if oneutils.is_valid_attr(args, 'keep_io_order'):
            _remap_io_names(onnx_model)
            if oneutils.is_valid_attr(args, 'save_intermediate'):
                fixed_path = os.path.join(tmpdir, input_stem + '~.onnx')
                onnx.save(onnx_model, fixed_path)

        if ext_path:
            # delegate the rest of the conversion to the extension
            alt_path = os.path.join(tmpdir, input_stem + '-alt.onnx')
            onnx.save(onnx_model, alt_path)

            ext_cmd = _make_ext_cmd(ext_path, alt_path, args)
            # log the command, like the tf2tfliteV2/tflite2circle commands below
            f.write((' '.join(ext_cmd) + '\n').encode())
            oneutils.run(ext_cmd, logfile=f)
            return

        # convert onnx to tf saved model
        tf_savedmodel = onnx_tf.backend.prepare(onnx_model)
        savedmodel_output_path = os.path.join(tmpdir, output_stem + '.savedmodel')
        tf_savedmodel.export_graph(savedmodel_output_path)

        # convert tf saved model to tflite
        tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
        tf2tfliteV2_output_path = os.path.join(tmpdir, output_stem + '.tflite')
        tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(
            args, tf2tfliteV2_path, savedmodel_output_path, tf2tfliteV2_output_path)
        f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
        oneutils.run(tf2tfliteV2_cmd, logfile=f)

        # convert tflite to circle
        tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
        tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(
            tflite2circle_path, tf2tfliteV2_output_path, args.output_path)
        f.write((' '.join(tflite2circle_cmd) + '\n').encode())
        oneutils.run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)
331
332
def main():
    """Entry point: parse options, merge the configuration file, validate and convert."""
    # parse arguments
    parser = _get_parser()
    args = _parse_arg(parser)

    # parse configuration file; use the shared section-name helper rather than
    # repeating the literal
    oneutils.parse_cfg(args.config, get_driver_cfg_section(), args)

    # verify arguments
    _verify_arg(parser, args)

    # convert
    _convert(args)
346
347
if __name__ == '__main__':
    # Run main() through oneutils.safemain, passing this script's path. safemain
    # is defined in onelib.utils; it presumably handles error reporting and the
    # exit status — confirm there.
    oneutils.safemain(main, __file__)