3a61d22fc6c4b3fcc517cbddc524d3b107214e2d
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-import-pytorch
1 #!/usr/bin/env bash
2 ''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
3 ''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
4 ''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
5 ''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
6 ''''exit 255                                                                            # '''
7
8 # Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 #    http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21
22 import argparse
23 import importlib
24 import inspect
25 import os
26 import sys
27 import tempfile
28 import torch
29 import onnx
30 import onnx_tf
31 import json
32 import zipfile
33
34 import onnx_legalizer
35 import onelib.make_cmd as _make_cmd
36 import onelib.utils as oneutils
37
38 # TODO Find better way to suppress trackback on error
39 sys.tracebacklimit = 0
40
41
42 def get_driver_spec():
43     return ("one-import-pytorch", oneutils.DriverType.IMPORTER)
44
45
46 def _get_parser():
47     parser = argparse.ArgumentParser(
48         description='command line tool to convert PyTorch to Circle')
49
50     oneutils.add_default_arg(parser)
51
52     ## converter arguments
53     converter_group = parser.add_argument_group('converter arguments')
54
55     # input and output path.
56     converter_group.add_argument(
57         '-i', '--input_path', type=str, help='full filepath of the input file')
58     converter_group.add_argument(
59         '-p', '--python_path', type=str, help='full filepath of the python model file')
60     converter_group.add_argument(
61         '-o', '--output_path', type=str, help='full filepath of the output file')
62
63     # input arrays.
64     converter_group.add_argument(
65         '-s',
66         '--input_shapes',
67         type=str,
68         help=
69         'Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")'
70     )
71     converter_group.add_argument(
72         '-t',
73         '--input_types',
74         type=str,
75         help='data types of input tensors, colon-separated (ex: float32, uint8, int32)')
76
77     # fixed options
78     tf2tflite_group = parser.add_argument_group('tf2tfliteV2 arguments')
79     tf2tflite_group.add_argument('--model_format', default='saved_model')
80     tf2tflite_group.add_argument('--converter_version', default='v2')
81
82     parser.add_argument('--unroll_rnn', action='store_true', help='Unroll RNN operators')
83     parser.add_argument(
84         '--unroll_lstm', action='store_true', help='Unroll LSTM operators')
85
86     # save intermediate file(s)
87     parser.add_argument(
88         '--save_intermediate',
89         action='store_true',
90         help='Save intermediate files to output folder')
91
92     return parser
93
94
95 def _verify_arg(parser, args):
96     """verify given arguments"""
97     # check if required arguments is given
98     missing = []
99     if not oneutils.is_valid_attr(args, 'input_path'):
100         missing.append('-i/--input_path')
101     if not oneutils.is_valid_attr(args, 'output_path'):
102         missing.append('-o/--output_path')
103     if not oneutils.is_valid_attr(args, 'input_shapes'):
104         missing.append('-s/--input_shapes')
105     if not oneutils.is_valid_attr(args, 'input_types'):
106         missing.append('-t/--input_types')
107
108     if len(missing):
109         parser.error('the following arguments are required: ' + ' '.join(missing))
110
111
112 def _parse_arg(parser):
113     args = parser.parse_args()
114     # print version
115     if args.version:
116         oneutils.print_version_and_exit(__file__)
117
118     return args
119
120
121 def _apply_verbosity(verbosity):
122     # NOTE
123     # TF_CPP_MIN_LOG_LEVEL
124     #   0 : INFO + WARNING + ERROR + FATAL
125     #   1 : WARNING + ERROR + FATAL
126     #   2 : ERROR + FATAL
127     #   3 : FATAL
128     if verbosity:
129         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
130     else:
131         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
132
133
134 def _parse_shapes(shapes_str):
135     shapes = []
136     for shape_str in shapes_str.split(":"):
137         if shape_str != "":
138             shapes += [list(map(int, shape_str.split(",")))]
139         else:
140             shapes += [[]]
141     return shapes
142
143
144 def _parse_types(types_str):
145     # There are no convenient way to create torch from string ot numpy dtype, so using this workaround
146     dtype_dict = {
147         "bool": torch.bool,
148         "uint8": torch.uint8,
149         "int8": torch.int8,
150         "int16": torch.int16,
151         "int32": torch.int32,
152         "int64": torch.int64,
153         "float16": torch.float16,
154         "float32": torch.float32,
155         "float64": torch.float64,
156         "complex64": torch.complex64,
157         "complex128": torch.complex128
158     }
159     array = types_str.split(",")
160     types = [dtype_dict[type_str.strip()] for type_str in array]
161     return types
162
163
164 # merge contents of module into global namespace
165 def _merge_module(module):
166     # is there an __all__?  if so respect it
167     if "__all__" in module.__dict__:
168         names = module.__dict__["__all__"]
169     else:
170         # otherwise we import all names that don't begin with _
171         names = [x for x in module.__dict__ if not x.startswith("_")]
172     globals().update({k: getattr(module, k) for k in names})
173
174
175 def _list_classes_from_module(module):
176     # Parsing the module to get all defined classes
177     is_member = lambda member: inspect.isclass(member) and member.__module__ == module.__name__
178     classes = [cls[1] for cls in inspect.getmembers(module, is_member)]
179     return classes
180
181
182 def _extract_pytorch_model(log_file, parameters_path, python_path):
183     log_file.write(('Trying to load saved model\n').encode())
184     python_model_path = os.path.abspath(python_path)
185     module_name = os.path.basename(python_model_path)
186     module_dir = os.path.dirname(python_model_path)
187     sys.path.append(module_dir)
188     log_file.write(('Trying to load given python module\n').encode())
189     module_loader = importlib.machinery.SourceFileLoader(module_name, python_model_path)
190     module_spec = importlib.util.spec_from_loader(module_name, module_loader)
191     python_model_module = importlib.util.module_from_spec(module_spec)
192
193     try:
194         module_loader.exec_module(python_model_module)
195     except:
196         raise ValueError('Failed to execute given python model file')
197
198     log_file.write(('Model python module is loaded\n').encode())
199     try:
200         # this branch assumes this parameters_path contains state_dict
201         state_dict = torch.load(parameters_path)
202         log_file.write(('Trying to find model class and fill it`s state dict\n').encode())
203         model_class_definitions = _list_classes_from_module(python_model_module)
204         if len(model_class_definitions) != 1:
205             raise ValueError("Expected only one class as model definition. {}".format(
206                 model_class_definitions))
207         pytorch_model_class = model_class_definitions[0]
208         model = pytorch_model_class()
209         model.load_state_dict(state_dict)
210         return model
211     except:
212         # this branch assumes this parameters_path contains "entire" model
213         _merge_module(python_model_module)
214         log_file.write(('Model python module is merged into main environment\n').encode())
215         model = torch.load(parameters_path)
216         log_file.write(('Pytorch model loaded\n').encode())
217         return model
218
219
220 def _extract_torchscript_model(log_file, input_path):
221     # assuming this is a pytorch script
222     log_file.write(('Trying to load TorchScript model\n').encode())
223     try:
224         pytorch_model = torch.jit.load(input_path)
225         return pytorch_model
226     except RuntimeError as e:
227         log_file.write((str(e) + '\n').encode())
228         log_file.write(
229             'Failed to import input file. Maybe this it contains only weights? Try pass "python_path" argument\n'.
230             encode())
231         raise
232     log_file.write(('TorchScript model is loaded\n').encode())
233
234
235 def _extract_mar_model(log_file, tmpdir, input_path):
236     mar_dir_path = os.path.join(tmpdir, 'mar')
237     with zipfile.ZipFile(input_path) as zip_input:
238         zip_input.extractall(path=mar_dir_path)
239     manifest_path = os.path.join(mar_dir_path, 'MAR-INF/MANIFEST.json')
240     with open(manifest_path) as manifest_file:
241         manifest = json.load(manifest_file)
242     serialized_file = os.path.join(mar_dir_path, manifest['model']['serializedFile'])
243     if 'modelFile' in manifest['model']:
244         model_file = os.path.join(mar_dir_path, manifest['model']['modelFile'])
245         return _extract_pytorch_model(log_file, serialized_file, model_file)
246     else:
247         return _extract_torchscript_model(log_file, serialized_file)
248
249
250 def _convert(args):
251     _apply_verbosity(args.verbose)
252
253     # get file path to log
254     dir_path = os.path.dirname(os.path.realpath(__file__))
255     logfile_path = os.path.realpath(args.output_path) + '.log'
256     with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
257         # save intermediate
258         if oneutils.is_valid_attr(args, 'save_intermediate'):
259             tmpdir = os.path.dirname(logfile_path)
260         # convert pytorch to onnx model
261         input_path = getattr(args, 'input_path')
262         model_file = getattr(args, 'python_path')
263
264         if input_path[-4:] == '.mar':
265             pytorch_model = _extract_mar_model(f, tmpdir, input_path)
266         elif model_file is None:
267             pytorch_model = _extract_torchscript_model(f, input_path)
268         else:
269             pytorch_model = _extract_pytorch_model(f, input_path, model_file)
270
271         input_shapes = _parse_shapes(getattr(args, 'input_shapes'))
272         input_types = _parse_types(getattr(args, 'input_types'))
273
274         if len(input_shapes) != len(input_types):
275             raise ValueError('number of input shapes and input types must be equal')
276
277         sample_inputs = []
278         for input_spec in zip(input_shapes, input_types):
279             sample_inputs += [torch.ones(input_spec[0], dtype=input_spec[1])]
280
281         f.write(('Trying to inference loaded model').encode())
282         sample_outputs = pytorch_model(*sample_inputs)
283         f.write(('Acquired sample outputs\n').encode())
284
285         onnx_output_name = os.path.splitext(os.path.basename(
286             args.output_path))[0] + '.onnx'
287         onnx_output_path = os.path.join(tmpdir, onnx_output_name)
288
289         onnx_saved = False
290         # some operations are not supported in early opset versions, try several
291         for onnx_opset_version in range(9, 15):
292             f.write(('Trying to save onnx model using opset version ' +
293                      str(onnx_opset_version) + '\n').encode())
294             try:
295                 torch.onnx.export(
296                     pytorch_model,
297                     tuple(sample_inputs),
298                     onnx_output_path,
299                     example_outputs=sample_outputs,
300                     opset_version=onnx_opset_version)
301                 onnx_saved = True
302                 break
303             except:
304                 f.write(('attempt failed\n').encode())
305
306         if not onnx_saved:
307             raise ValueError('Failed to save temporary onnx model')
308
309         # convert onnx to tf saved mode
310         onnx_model = onnx.load(onnx_output_path)
311
312         options = onnx_legalizer.LegalizeOptions()
313         options.unroll_rnn = oneutils.is_valid_attr(args, 'unroll_rnn')
314         options.unroll_lstm = oneutils.is_valid_attr(args, 'unroll_lstm')
315         onnx_legalizer.legalize(onnx_model, options)
316
317         tf_savedmodel = onnx_tf.backend.prepare(onnx_model)
318
319         savedmodel_name = os.path.splitext(os.path.basename(
320             args.output_path))[0] + '.savedmodel'
321         savedmodel_output_path = os.path.join(tmpdir, savedmodel_name)
322         tf_savedmodel.export_graph(savedmodel_output_path)
323
324         # make a command to convert from tf to tflite
325         tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
326         tf2tfliteV2_output_name = os.path.splitext(os.path.basename(
327             args.output_path))[0] + '.tflite'
328         tf2tfliteV2_output_path = os.path.join(tmpdir, tf2tfliteV2_output_name)
329
330         del args.input_shapes
331         tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(
332             args, tf2tfliteV2_path, savedmodel_output_path, tf2tfliteV2_output_path)
333
334         f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
335
336         # convert tf to tflite
337         oneutils.run(tf2tfliteV2_cmd, logfile=f)
338
339         # make a command to convert from tflite to circle
340         tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
341         tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
342                                                              tf2tfliteV2_output_path,
343                                                              getattr(args, 'output_path'))
344
345         f.write((' '.join(tflite2circle_cmd) + '\n').encode())
346
347         # convert tflite to circle
348         oneutils.run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)
349
350
351 def main():
352     # parse arguments
353     parser = _get_parser()
354     args = _parse_arg(parser)
355
356     # parse configuration file
357     oneutils.parse_cfg(args.config, 'one-import-pytorch', args)
358
359     # verify arguments
360     _verify_arg(parser, args)
361
362     # convert
363     _convert(args)
364
365
366 if __name__ == '__main__':
367     oneutils.safemain(main, __file__)