# bash/Python polyglot preamble.
# To bash, the leading '''' is two empty strings, so every line below is an
# ordinary command and the trailing "# '''" is just a comment.
# To Python, every line below is a triple-quoted string literal and is skipped.
# Under bash this re-executes the file with the interpreter found in the venv
# directory next to the script, or prints an error when it does not exist.
# All expansions are quoted so an install path containing spaces still works.
''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH="${SCRIPT_PATH}/venv/bin/python" # '''
''''test -f "${PY_PATH}" && exec "${PY_PATH}" "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
8 # Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
14 # http://www.apache.org/licenses/LICENSE-2.0
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
29 # ONNX legalizer is an optional feature
30 # It enables conversion of some operations, but is still in an experimental phase for now
# Flag checked by the conversion code before it calls onnx_legalizer.legalize().
# NOTE(review): the two assignments below look like the success / failure arms of
# a guarded `import onnx_legalizer` whose surrounding lines are not visible in
# this excerpt - confirm against the full file.
33 _onnx_legalizer_enabled = True
35 _onnx_legalizer_enabled = False
37 import onelib.make_cmd as _make_cmd
38 import onelib.utils as oneutils
40 # TODO Find better way to suppress traceback on error
# With a limit of 0 the interpreter prints only the exception type and message.
41 sys.tracebacklimit = 0
def get_driver_cfg_section():
    """Return the section name this driver uses: ``one-import-onnx``."""
    section_name = "one-import-onnx"
    return section_name
# Statements that build the command-line parser (obtained elsewhere in this
# file through `_get_parser()`).
# NOTE(review): the enclosing `def` line, several `add_argument(` openers, some
# option names/actions and the final return are not visible in this excerpt -
# confirm against the full file before editing.
49 parser = argparse.ArgumentParser(
50 description='command line tool to convert ONNX to circle')
# NOTE(review): presumably registers the options shared by all drivers (the
# script later reads args.verbose and args.config) - confirm in onelib.utils.
52 oneutils.add_default_arg(parser)
54 ## tf2tfliteV2 arguments
# These options live in their own group, titled 'converter arguments'.
55 tf2tfliteV2_group = parser.add_argument_group('converter arguments')
57 # input and output path.
# Neither path is marked required here; _verify_arg() checks them.
58 tf2tfliteV2_group.add_argument(
59 '-i', '--input_path', type=str, help='full filepath of the input file')
60 tf2tfliteV2_group.add_argument(
61 '-o', '--output_path', type=str, help='full filepath of the output file')
63 # input and output arrays.
64 tf2tfliteV2_group.add_argument(
68 help='names of the input arrays, comma-separated')
69 tf2tfliteV2_group.add_argument(
73 help='names of the output arrays, comma-separated')
# Options registered without help text; their defaults ('saved_model', 'v2')
# NOTE(review): presumably describe the SavedModel handed to tf2tfliteV2 - confirm.
76 tf2tfliteV2_group.add_argument('--model_format', default='saved_model')
77 tf2tfliteV2_group.add_argument('--converter_version', default='v2')
# Legalizer switches: read later as 'unroll_rnn' / 'unroll_lstm'.
79 parser.add_argument('--unroll_rnn', action='store_true', help='Unroll RNN operators')
81 '--unroll_lstm', action='store_true', help='Unroll LSTM operators')
# Help text of the option read later as 'keep_io_order'.
86 'Ensure generated circle model preserves the I/O order of the original onnx model.'
89 # save intermediate file(s)
91 '--save_intermediate',
93 help='Save intermediate files to output folder')
95 # experimental options
97 '--experimental_disable_batchmatmul_unfold',
99 help='Experimental disable BatchMatMul unfold')
104 def _verify_arg(parser, args):
105 """verify given arguments"""
106 # check if required arguments is given
108 if not oneutils.is_valid_attr(args, 'input_path'):
109 missing.append('-i/--input_path')
110 if not oneutils.is_valid_attr(args, 'output_path'):
111 missing.append('-o/--output_path')
113 parser.error('the following arguments are required: ' + ' '.join(missing))
116 def _parse_arg(parser):
117 args = parser.parse_args()
120 oneutils.print_version_and_exit(__file__)
125 def _apply_verbosity(verbosity):
127 # TF_CPP_MIN_LOG_LEVEL
128 # 0 : INFO + WARNING + ERROR + FATAL
129 # 1 : WARNING + ERROR + FATAL
133 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
135 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
138 # The index of input/output is added in front of the name. For example,
139 # Original input names: 'a', 'c', 'b'
140 # Renamed: '0001_a', '0002_c', '0003_b'
141 # This will preserve I/O order after import.
142 def _remap_io_names(onnx_model):
143 # gather existing name of I/O and generate new name of I/O in sort order
149 # some models may have initializers as inputs. ignore them.
150 for initializer in onnx_model.graph.initializer:
151 initializers.append(initializer.name)
152 for idx in range(0, len(onnx_model.graph.input)):
153 name = onnx_model.graph.input[idx].name
154 if not name in initializers:
155 input_nodes.append(name)
156 remap_inputs.append(format(idx + 1, '04d') + '_' + name)
157 for idx in range(0, len(onnx_model.graph.output)):
158 name = onnx_model.graph.output[idx].name
159 output_nodes.append(name)
160 remap_outputs.append(format(idx + 1, '04d') + '_' + name)
161 # change names for graph input
162 for i in range(len(onnx_model.graph.input)):
163 if onnx_model.graph.input[i].name in input_nodes:
164 to_rename = onnx_model.graph.input[i].name
165 idx = input_nodes.index(to_rename)
166 onnx_model.graph.input[i].name = remap_inputs[idx]
167 # change names of all nodes in the graph
168 for i in range(len(onnx_model.graph.node)):
169 # check node.input is to change to remap_inputs or remap_outputs
170 for j in range(len(onnx_model.graph.node[i].input)):
171 if onnx_model.graph.node[i].input[j] in input_nodes:
172 to_rename = onnx_model.graph.node[i].input[j]
173 idx = input_nodes.index(to_rename)
174 onnx_model.graph.node[i].input[j] = remap_inputs[idx]
175 if onnx_model.graph.node[i].input[j] in output_nodes:
176 to_rename = onnx_model.graph.node[i].input[j]
177 idx = output_nodes.index(to_rename)
178 onnx_model.graph.node[i].input[j] = remap_outputs[idx]
179 # check node.output is to change to remap_inputs or remap_outputs
180 for j in range(len(onnx_model.graph.node[i].output)):
181 if onnx_model.graph.node[i].output[j] in output_nodes:
182 to_rename = onnx_model.graph.node[i].output[j]
183 idx = output_nodes.index(to_rename)
184 onnx_model.graph.node[i].output[j] = remap_outputs[idx]
185 if onnx_model.graph.node[i].output[j] in input_nodes:
186 to_rename = onnx_model.graph.node[i].output[j]
187 idx = input_nodes.index(to_rename)
188 onnx_model.graph.node[i].output[j] = remap_inputs[idx]
189 # change names for graph output
190 for i in range(len(onnx_model.graph.output)):
191 if onnx_model.graph.output[i].name in output_nodes:
192 to_rename = onnx_model.graph.output[i].name
193 idx = output_nodes.index(to_rename)
194 onnx_model.graph.output[i].name = remap_outputs[idx]
# Conversion pipeline: ONNX model -> TensorFlow SavedModel -> tflite -> circle.
# NOTE(review): the enclosing `def` line is not visible in this excerpt and the
# indentation was lost; `args` is presumably the namespace produced by
# _parse_arg() - confirm against the full file.
198 _apply_verbosity(args.verbose)
200 # get file path to log
# dir_path is the directory of this script; the helper tools are looked up there.
201 dir_path = os.path.dirname(os.path.realpath(__file__))
# The log file sits next to the output file: '<output_path>.log'.
202 logfile_path = os.path.realpath(args.output_path) + '.log'
# The log is opened in binary mode (commands are written with .encode()), and a
# temporary directory is created to hold the intermediate files.
204 with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
# With save_intermediate set, intermediates are written to the folder holding
# the log/output file instead of the temporary directory.
206 if oneutils.is_valid_attr(args, 'save_intermediate'):
207 tmpdir = os.path.dirname(logfile_path)
208 # convert onnx to tf saved model
209 onnx_model = onnx.load(getattr(args, 'input_path'))
# Optional legalization step, run only when the legalizer flag is set.
210 if _onnx_legalizer_enabled:
# NOTE(review): this binds the LegalizeOptions name itself (no call), so the two
# attributes below are set on that object directly - confirm this is intended.
211 options = onnx_legalizer.LegalizeOptions
212 options.unroll_rnn = oneutils.is_valid_attr(args, 'unroll_rnn')
213 options.unroll_lstm = oneutils.is_valid_attr(args, 'unroll_lstm')
214 onnx_legalizer.legalize(onnx_model, options)
# Prefix I/O names with their position when keep_io_order is requested.
215 if oneutils.is_valid_attr(args, 'keep_io_order'):
216 _remap_io_names(onnx_model)
# Save the (possibly modified) ONNX model as '<input basename>~.onnx'.
217 if oneutils.is_valid_attr(args, 'save_intermediate'):
218 basename = os.path.basename(getattr(args, 'input_path'))
219 fixed_path = os.path.join(tmpdir,
220 os.path.splitext(basename)[0] + '~.onnx')
221 onnx.save(onnx_model, fixed_path)
222 tf_savedmodel = onnx_tf.backend.prepare(onnx_model)
# The SavedModel is exported as '<output basename>.savedmodel' inside tmpdir.
224 savedmodel_name = os.path.splitext(os.path.basename(
225 args.output_path))[0] + '.savedmodel'
226 savedmodel_output_path = os.path.join(tmpdir, savedmodel_name)
227 tf_savedmodel.export_graph(savedmodel_output_path)
229 # make a command to convert from tf to tflite
230 tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
# The tflite intermediate is '<output basename>.tflite' inside tmpdir.
231 tf2tfliteV2_output_name = os.path.splitext(os.path.basename(
232 args.output_path))[0] + '.tflite'
233 tf2tfliteV2_output_path = os.path.join(tmpdir, tf2tfliteV2_output_name)
235 tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(
236 args, tf2tfliteV2_path, savedmodel_output_path, tf2tfliteV2_output_path)
# Record the command line in the log before running it.
238 f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
240 # convert tf to tflite
241 oneutils.run(tf2tfliteV2_cmd, logfile=f)
243 # make a command to convert from tflite to circle
244 tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
245 tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
246 tf2tfliteV2_output_path,
247 getattr(args, 'output_path'))
# Record the command line in the log before running it.
249 f.write((' '.join(tflite2circle_cmd) + '\n').encode())
251 # convert tflite to circle
252 oneutils.run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)
# Entry-point sequence: build the parser, parse the command line, parse the
# configuration file, then verify the mandatory arguments.
# NOTE(review): the enclosing `def` line and any statement following the
# verification step are not visible in this excerpt - confirm against the full file.
257 parser = _get_parser()
258 args = _parse_arg(parser)
260 # parse configuration file
# 'one-import-onnx' is the same string get_driver_cfg_section() returns.
# NOTE(review): presumably settings from the file named by args.config are
# stored on `args` - confirm in onelib.utils.
261 oneutils.parse_cfg(args.config, 'one-import-onnx', args)
# Verification runs after the configuration file has been parsed.
264 _verify_arg(parser, args)
# Run main only when the file is executed as a script, not when it is imported.
# NOTE(review): safemain presumably wraps main with error reporting - confirm in onelib.utils.
270 if __name__ == '__main__':
271 oneutils.safemain(main, __file__)