''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
''''exit 255                                                                             # '''

# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import importlib.machinery
import importlib.util
import inspect
import json
import os
import sys
import tempfile
import zipfile

import torch
import onnx
import onnx_tf

import onnx_legalizer
import onelib.make_cmd as _make_cmd
import onelib.utils as oneutils

# TODO Find a better way to suppress traceback on error
sys.tracebacklimit = 0
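
# NOTE get_driver_spec() is presumably used by the one-cmds driver to register
#      this tool as an importer under the name 'one-import-pytorch'.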

def get_driver_spec():
    return ("one-import-pytorch", oneutils.DriverType.IMPORTER)

def _get_parser():
    parser = argparse.ArgumentParser(
        description='command line tool to convert PyTorch to Circle')

    oneutils.add_default_arg(parser)

    ## converter arguments
    converter_group = parser.add_argument_group('converter arguments')

    # input and output path.
    converter_group.add_argument(
        '-i', '--input_path', type=str, help='full filepath of the input file')
    converter_group.add_argument(
        '-p', '--python_path', type=str, help='full filepath of the python model file')
    converter_group.add_argument(
        '-o', '--output_path', type=str, help='full filepath of the output file')

    # input shapes and types
    converter_group.add_argument(
        '-s',
        '--input_shapes',
        type=str,
        help='shapes of input tensors, colon-separated (ex: "1,4,4,3:1,20,20,3")')
    converter_group.add_argument(
        '-t',
        '--input_types',
        type=str,
        help='data types of input tensors, comma-separated (ex: float32,uint8,int32)')

    # fixed options for tf2tfliteV2
    tf2tflite_group = parser.add_argument_group('tf2tfliteV2 arguments')
    tf2tflite_group.add_argument('--model_format', default='saved_model')
    tf2tflite_group.add_argument('--converter_version', default='v2')

    parser.add_argument('--unroll_rnn', action='store_true', help='Unroll RNN operators')
    parser.add_argument(
        '--unroll_lstm', action='store_true', help='Unroll LSTM operators')

    # save intermediate file(s)
    parser.add_argument(
        '--save_intermediate',
        action='store_true',
        help='Save intermediate files to output folder')

    return parser
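
# NOTE Example invocation (hypothetical file names), matching the arguments above:
#   one-import-pytorch -i module.pth -o module.circle -s "1,3,224,224" -t float32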

def _verify_arg(parser, args):
    """verify given arguments"""
    # check if required arguments are given
    missing = []
    if not oneutils.is_valid_attr(args, 'input_path'):
        missing.append('-i/--input_path')
    if not oneutils.is_valid_attr(args, 'output_path'):
        missing.append('-o/--output_path')
    if not oneutils.is_valid_attr(args, 'input_shapes'):
        missing.append('-s/--input_shapes')
    if not oneutils.is_valid_attr(args, 'input_types'):
        missing.append('-t/--input_types')
    if len(missing):
        parser.error('the following arguments are required: ' + ' '.join(missing))

def _parse_arg(parser):
    args = parser.parse_args()
    # print version
    if args.version:
        oneutils.print_version_and_exit(__file__)

    return args

def _apply_verbosity(verbosity):
    # NOTE
    # TF_CPP_MIN_LOG_LEVEL
    #   0 : INFO + WARNING + ERROR + FATAL
    #   1 : WARNING + ERROR + FATAL
    #   2 : ERROR + FATAL
    #   3 : FATAL
    if verbosity:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
    else:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

def _parse_shapes(shapes_str):
    # e.g. "1,4,4,3:1,20,20,3" -> [[1, 4, 4, 3], [1, 20, 20, 3]]
    shapes = []
    for shape_str in shapes_str.split(":"):
        shapes += [list(map(int, shape_str.split(",")))]
    return shapes

def _parse_types(types_str):
    # There is no convenient way to create a torch dtype from a string or numpy
    # dtype, so use this lookup table as a workaround.
    dtype_dict = {
        "bool": torch.bool,
        "uint8": torch.uint8,
        "int8": torch.int8,
        "int16": torch.int16,
        "int32": torch.int32,
        "int64": torch.int64,
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
        "complex64": torch.complex64,
        "complex128": torch.complex128
    }
    # e.g. "float32,int32" -> [torch.float32, torch.int32]
    array = types_str.split(",")
    types = [dtype_dict[type_str.strip()] for type_str in array]
    return types
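
# NOTE _merge_module() below is presumably needed so that pickle can resolve
#      model classes when torch.load() restores an "entire" pickled model that
#      was saved from a script where those classes lived in the top-level
#      namespace.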

# merge contents of module into global namespace
def _merge_module(module):
    # is there an __all__? if so, respect it
    if "__all__" in module.__dict__:
        names = module.__dict__["__all__"]
    else:
        # otherwise we import all names that don't begin with _
        names = [x for x in module.__dict__ if not x.startswith("_")]
    globals().update({k: getattr(module, k) for k in names})

def _list_classes_from_module(module):
    # Parsing the module to get all defined classes
    is_member = lambda member: inspect.isclass(member) and member.__module__ == module.__name__
    classes = [cls[1] for cls in inspect.getmembers(module, is_member)]
    return classes

def _extract_pytorch_model(log_file, parameters_path, python_path):
    log_file.write(('Trying to load saved model\n').encode())
    python_model_path = os.path.abspath(python_path)
    module_name = os.path.basename(python_model_path)
    module_dir = os.path.dirname(python_model_path)
    sys.path.append(module_dir)
    log_file.write(('Trying to load given python module\n').encode())
    module_loader = importlib.machinery.SourceFileLoader(module_name, python_model_path)
    module_spec = importlib.util.spec_from_loader(module_name, module_loader)
    python_model_module = importlib.util.module_from_spec(module_spec)

    try:
        module_loader.exec_module(python_model_module)
    except:
        raise ValueError('Failed to execute given python model file')

    log_file.write(('Model python module is loaded\n').encode())
    try:
        # this branch assumes this parameters_path contains state_dict
        state_dict = torch.load(parameters_path)
        log_file.write(('Trying to find model class and fill its state dict\n').encode())
        model_class_definitions = _list_classes_from_module(python_model_module)
        if len(model_class_definitions) != 1:
            raise ValueError("Expected only one class as model definition. {}".format(
                model_class_definitions))
        pytorch_model_class = model_class_definitions[0]
        model = pytorch_model_class()
        model.load_state_dict(state_dict)
        return model
    except:
        # this branch assumes this parameters_path contains "entire" model
        _merge_module(python_model_module)
        log_file.write(('Model python module is merged into main environment\n').encode())
        model = torch.load(parameters_path)
        log_file.write(('Pytorch model loaded\n').encode())
        return model

def _extract_torchscript_model(log_file, input_path):
    # assuming this is a pytorch script
    log_file.write(('Trying to load TorchScript model\n').encode())
    try:
        pytorch_model = torch.jit.load(input_path)
    except RuntimeError as e:
        log_file.write((str(e) + '\n').encode())
        log_file.write(
            ('Failed to import input file. Maybe it contains only weights? '
             'Try passing the "python_path" argument\n').encode())
        raise
    log_file.write(('TorchScript model is loaded\n').encode())
    return pytorch_model
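
# NOTE A .mar file is a TorchServe model archive: a zip containing
#      MAR-INF/MANIFEST.json, which points at the serialized model and,
#      optionally, the python file that defines the model class.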

def _extract_mar_model(log_file, tmpdir, input_path):
    mar_dir_path = os.path.join(tmpdir, 'mar')
    with zipfile.ZipFile(input_path) as zip_input:
        zip_input.extractall(path=mar_dir_path)
        manifest_path = os.path.join(mar_dir_path, 'MAR-INF/MANIFEST.json')
        with open(manifest_path) as manifest_file:
            manifest = json.load(manifest_file)
            serialized_file = os.path.join(mar_dir_path, manifest['model']['serializedFile'])
            if 'modelFile' in manifest['model']:
                model_file = os.path.join(mar_dir_path, manifest['model']['modelFile'])
                return _extract_pytorch_model(log_file, serialized_file, model_file)
            else:
                return _extract_torchscript_model(log_file, serialized_file)
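
# NOTE Conversion pipeline implemented in _convert():
#      PyTorch/TorchScript/.mar -> ONNX (torch.onnx.export)
#      -> legalized ONNX (onnx_legalizer) -> TF SavedModel (onnx_tf)
#      -> TFLite (tf2tfliteV2.py) -> Circle (tflite2circle)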

def _convert(args):
    _apply_verbosity(args.verbose)

    # get file path to log
    dir_path = os.path.dirname(os.path.realpath(__file__))
    logfile_path = os.path.realpath(args.output_path) + '.log'

    with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
        # save intermediate files in the output folder instead of the temporary folder
        if oneutils.is_valid_attr(args, 'save_intermediate'):
            tmpdir = os.path.dirname(logfile_path)
        # convert pytorch to onnx model
        input_path = getattr(args, 'input_path')
        model_file = getattr(args, 'python_path')

        if input_path[-4:] == '.mar':
            pytorch_model = _extract_mar_model(f, tmpdir, input_path)
        elif model_file is None:
            pytorch_model = _extract_torchscript_model(f, input_path)
        else:
            pytorch_model = _extract_pytorch_model(f, input_path, model_file)

        input_shapes = _parse_shapes(getattr(args, 'input_shapes'))
        input_types = _parse_types(getattr(args, 'input_types'))

        if len(input_shapes) != len(input_types):
            raise ValueError('number of input shapes and input types must be equal')

        # build one dummy tensor per declared input; they are only used to trace
        # the model for export
        sample_inputs = []
        for input_spec in zip(input_shapes, input_types):
            sample_inputs += [torch.ones(input_spec[0], dtype=input_spec[1])]

        f.write(('Trying to run inference on the loaded model\n').encode())
        sample_outputs = pytorch_model(*sample_inputs)
        f.write(('Acquired sample outputs\n').encode())

        onnx_output_name = os.path.splitext(os.path.basename(
            args.output_path))[0] + '.onnx'
        onnx_output_path = os.path.join(tmpdir, onnx_output_name)

        onnx_saved = False
        # some operations are not supported in early opset versions, try several
        for onnx_opset_version in range(9, 15):
            f.write(('Trying to save onnx model using opset version ' +
                     str(onnx_opset_version) + '\n').encode())
            try:
                torch.onnx.export(
                    pytorch_model,
                    tuple(sample_inputs),
                    onnx_output_path,
                    example_outputs=sample_outputs,
                    opset_version=onnx_opset_version)
                onnx_saved = True
                break
            except:
                f.write(('attempt failed\n').encode())

        if not onnx_saved:
            raise ValueError('Failed to save temporary onnx model')

        # convert onnx to tf saved model
        onnx_model = onnx.load(onnx_output_path)

        options = onnx_legalizer.LegalizeOptions()
        options.unroll_rnn = oneutils.is_valid_attr(args, 'unroll_rnn')
        options.unroll_lstm = oneutils.is_valid_attr(args, 'unroll_lstm')
        onnx_legalizer.legalize(onnx_model, options)

        tf_savedmodel = onnx_tf.backend.prepare(onnx_model)

        savedmodel_name = os.path.splitext(os.path.basename(
            args.output_path))[0] + '.savedmodel'
        savedmodel_output_path = os.path.join(tmpdir, savedmodel_name)
        tf_savedmodel.export_graph(savedmodel_output_path)

        # make a command to convert from tf to tflite
        tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
        tf2tfliteV2_output_name = os.path.splitext(os.path.basename(
            args.output_path))[0] + '.tflite'
        tf2tfliteV2_output_path = os.path.join(tmpdir, tf2tfliteV2_output_name)

        # input shapes were already applied above; do not forward them to tf2tfliteV2
        del args.input_shapes
        tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(
            args, tf2tfliteV2_path, savedmodel_output_path, tf2tfliteV2_output_path)

        f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())

        # convert tf to tflite
        oneutils.run(tf2tfliteV2_cmd, logfile=f)

        # make a command to convert from tflite to circle
        tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
        tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
                                                             tf2tfliteV2_output_path,
                                                             getattr(args, 'output_path'))

        f.write((' '.join(tflite2circle_cmd) + '\n').encode())

        # convert tflite to circle
        oneutils.run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)

def main():
    # parse arguments
    parser = _get_parser()
    args = _parse_arg(parser)

    # parse configuration file
    oneutils.parse_cfg(args.config, 'one-import-pytorch', args)

    # verify arguments
    _verify_arg(parser, args)

    # convert
    _convert(args)

if __name__ == '__main__':
    oneutils.safemain(main, __file__)