24edea645b252668c9df7fd80001b5a12f9d80eb
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-import-onnx
1 #!/usr/bin/env bash
2 ''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
3 ''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
4 ''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
5 ''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
6 ''''exit 255                                                                            # '''
7
8 # Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 #    http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21
22 import argparse
23 import os
24 import sys
25 import tempfile
26 import onnx
27 import onnx_tf
28
29 # ONNX legalizer is an optional feature
30 # It enables conversion of some operations, but in experimental phase for now
31 try:
32     import onnx_legalizer
33     _onnx_legalizer_enabled = True
34 except ImportError:
35     _onnx_legalizer_enabled = False
36
37 import onelib.make_cmd as _make_cmd
38 import onelib.utils as oneutils
39
40 # TODO Find better way to suppress trackback on error
41 sys.tracebacklimit = 0
42
43
44 def get_driver_cfg_section():
45     return "one-import-onnx"
46
47
48 def _get_parser():
49     parser = argparse.ArgumentParser(
50         description='command line tool to convert ONNX to circle')
51
52     oneutils.add_default_arg(parser)
53
54     ## tf2tfliteV2 arguments
55     tf2tfliteV2_group = parser.add_argument_group('converter arguments')
56
57     # input and output path.
58     tf2tfliteV2_group.add_argument(
59         '-i', '--input_path', type=str, help='full filepath of the input file')
60     tf2tfliteV2_group.add_argument(
61         '-o', '--output_path', type=str, help='full filepath of the output file')
62
63     # input and output arrays.
64     tf2tfliteV2_group.add_argument(
65         '-I',
66         '--input_arrays',
67         type=str,
68         help='names of the input arrays, comma-separated')
69     tf2tfliteV2_group.add_argument(
70         '-O',
71         '--output_arrays',
72         type=str,
73         help='names of the output arrays, comma-separated')
74
75     # fixed options
76     tf2tfliteV2_group.add_argument('--model_format', default='saved_model')
77     tf2tfliteV2_group.add_argument('--converter_version', default='v2')
78
79     parser.add_argument('--unroll_rnn', action='store_true', help='Unroll RNN operators')
80     parser.add_argument(
81         '--unroll_lstm', action='store_true', help='Unroll LSTM operators')
82     parser.add_argument(
83         '--keep_io_order',
84         action='store_true',
85         help=
86         'Ensure generated circle model preserves the I/O order of the original onnx model.'
87     )
88
89     # save intermediate file(s)
90     parser.add_argument(
91         '--save_intermediate',
92         action='store_true',
93         help='Save intermediate files to output folder')
94
95     # experimental options
96     parser.add_argument(
97         '--experimental_disable_batchmatmul_unfold',
98         action='store_true',
99         help='Experimental disable BatchMatMul unfold')
100
101     return parser
102
103
104 def _verify_arg(parser, args):
105     """verify given arguments"""
106     # check if required arguments is given
107     missing = []
108     if not oneutils.is_valid_attr(args, 'input_path'):
109         missing.append('-i/--input_path')
110     if not oneutils.is_valid_attr(args, 'output_path'):
111         missing.append('-o/--output_path')
112     if len(missing):
113         parser.error('the following arguments are required: ' + ' '.join(missing))
114
115
116 def _parse_arg(parser):
117     args = parser.parse_args()
118     # print version
119     if args.version:
120         oneutils.print_version_and_exit(__file__)
121
122     return args
123
124
125 def _apply_verbosity(verbosity):
126     # NOTE
127     # TF_CPP_MIN_LOG_LEVEL
128     #   0 : INFO + WARNING + ERROR + FATAL
129     #   1 : WARNING + ERROR + FATAL
130     #   2 : ERROR + FATAL
131     #   3 : FATAL
132     if verbosity:
133         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
134     else:
135         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
136
137
138 # The index of input/output is added in front of the name. For example,
139 # Original input names: 'a', 'c', 'b'
140 # Renamed: '0001_a', '0002_c', '0003_b'
141 # This will preserve I/O order after import.
142 def _remap_io_names(onnx_model):
143     # gather existing name of I/O and generate new name of I/O in sort order
144     input_nodes = []
145     output_nodes = []
146     remap_inputs = []
147     remap_outputs = []
148     initializers = []
149     # some models may have initializers as inputs. ignore them.
150     for initializer in onnx_model.graph.initializer:
151         initializers.append(initializer.name)
152     for idx in range(0, len(onnx_model.graph.input)):
153         name = onnx_model.graph.input[idx].name
154         if not name in initializers:
155             input_nodes.append(name)
156             remap_inputs.append(format(idx + 1, '04d') + '_' + name)
157     for idx in range(0, len(onnx_model.graph.output)):
158         name = onnx_model.graph.output[idx].name
159         output_nodes.append(name)
160         remap_outputs.append(format(idx + 1, '04d') + '_' + name)
161     # change names for graph input
162     for i in range(len(onnx_model.graph.input)):
163         if onnx_model.graph.input[i].name in input_nodes:
164             to_rename = onnx_model.graph.input[i].name
165             idx = input_nodes.index(to_rename)
166             onnx_model.graph.input[i].name = remap_inputs[idx]
167     # change names of all nodes in the graph
168     for i in range(len(onnx_model.graph.node)):
169         # check node.input is to change to remap_inputs or remap_outputs
170         for j in range(len(onnx_model.graph.node[i].input)):
171             if onnx_model.graph.node[i].input[j] in input_nodes:
172                 to_rename = onnx_model.graph.node[i].input[j]
173                 idx = input_nodes.index(to_rename)
174                 onnx_model.graph.node[i].input[j] = remap_inputs[idx]
175             if onnx_model.graph.node[i].input[j] in output_nodes:
176                 to_rename = onnx_model.graph.node[i].input[j]
177                 idx = output_nodes.index(to_rename)
178                 onnx_model.graph.node[i].input[j] = remap_outputs[idx]
179         # check node.output is to change to remap_inputs or remap_outputs
180         for j in range(len(onnx_model.graph.node[i].output)):
181             if onnx_model.graph.node[i].output[j] in output_nodes:
182                 to_rename = onnx_model.graph.node[i].output[j]
183                 idx = output_nodes.index(to_rename)
184                 onnx_model.graph.node[i].output[j] = remap_outputs[idx]
185             if onnx_model.graph.node[i].output[j] in input_nodes:
186                 to_rename = onnx_model.graph.node[i].output[j]
187                 idx = input_nodes.index(to_rename)
188                 onnx_model.graph.node[i].output[j] = remap_inputs[idx]
189     # change names for graph output
190     for i in range(len(onnx_model.graph.output)):
191         if onnx_model.graph.output[i].name in output_nodes:
192             to_rename = onnx_model.graph.output[i].name
193             idx = output_nodes.index(to_rename)
194             onnx_model.graph.output[i].name = remap_outputs[idx]
195
196
197 def _convert(args):
198     _apply_verbosity(args.verbose)
199
200     # get file path to log
201     dir_path = os.path.dirname(os.path.realpath(__file__))
202     logfile_path = os.path.realpath(args.output_path) + '.log'
203
204     with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
205         # save intermediate
206         if oneutils.is_valid_attr(args, 'save_intermediate'):
207             tmpdir = os.path.dirname(logfile_path)
208         # convert onnx to tf saved model
209         onnx_model = onnx.load(getattr(args, 'input_path'))
210         if _onnx_legalizer_enabled:
211             options = onnx_legalizer.LegalizeOptions
212             options.unroll_rnn = oneutils.is_valid_attr(args, 'unroll_rnn')
213             options.unroll_lstm = oneutils.is_valid_attr(args, 'unroll_lstm')
214             onnx_legalizer.legalize(onnx_model, options)
215         if oneutils.is_valid_attr(args, 'keep_io_order'):
216             _remap_io_names(onnx_model)
217             if oneutils.is_valid_attr(args, 'save_intermediate'):
218                 basename = os.path.basename(getattr(args, 'input_path'))
219                 fixed_path = os.path.join(tmpdir,
220                                           os.path.splitext(basename)[0] + '~.onnx')
221                 onnx.save(onnx_model, fixed_path)
222         tf_savedmodel = onnx_tf.backend.prepare(onnx_model)
223
224         savedmodel_name = os.path.splitext(os.path.basename(
225             args.output_path))[0] + '.savedmodel'
226         savedmodel_output_path = os.path.join(tmpdir, savedmodel_name)
227         tf_savedmodel.export_graph(savedmodel_output_path)
228
229         # make a command to convert from tf to tflite
230         tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
231         tf2tfliteV2_output_name = os.path.splitext(os.path.basename(
232             args.output_path))[0] + '.tflite'
233         tf2tfliteV2_output_path = os.path.join(tmpdir, tf2tfliteV2_output_name)
234
235         tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(
236             args, tf2tfliteV2_path, savedmodel_output_path, tf2tfliteV2_output_path)
237
238         f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
239
240         # convert tf to tflite
241         oneutils.run(tf2tfliteV2_cmd, logfile=f)
242
243         # make a command to convert from tflite to circle
244         tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
245         tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
246                                                              tf2tfliteV2_output_path,
247                                                              getattr(args, 'output_path'))
248
249         f.write((' '.join(tflite2circle_cmd) + '\n').encode())
250
251         # convert tflite to circle
252         oneutils.run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)
253
254
255 def main():
256     # parse arguments
257     parser = _get_parser()
258     args = _parse_arg(parser)
259
260     # parse configuration file
261     oneutils.parse_cfg(args.config, 'one-import-onnx', args)
262
263     # verify arguments
264     _verify_arg(parser, args)
265
266     # convert
267     _convert(args)
268
269
270 if __name__ == '__main__':
271     oneutils.safemain(main, __file__)