Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / res / PyTorchExamples / ptem.py
1 #!/usr/bin/env python
2
3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 # PyTorch Example manager
18
19 import torch
20 import onnx
21 import onnx_tf
22 import tensorflow as tf
23 import importlib
24 import argparse
25
26 from pathlib import Path
27
28 print("PyTorch version=", torch.__version__)
29 print("ONNX version=", onnx.__version__)
30 print("ONNX-TF version=", onnx_tf.__version__)
31 print("TF version=", tf.__version__)
32
33 parser = argparse.ArgumentParser(description='Process PyTorch python examples')
34
35 parser.add_argument('examples', metavar='EXAMPLES', nargs='+')
36
37 args = parser.parse_args()
38
39 output_folder = "./output/"
40
41 Path(output_folder).mkdir(parents=True, exist_ok=True)
42
43 for example in args.examples:
44     # load example code
45     module = importlib.import_module("examples." + example)
46
47     # save .pth
48     torch.save(module._model_, output_folder + example + ".pth")
49     print("Generate '" + example + ".pth' - Done")
50
51     opset_version = 9
52     if hasattr(module._model_, 'onnx_opset_version'):
53         opset_version = module._model_.onnx_opset_version()
54
55     onnx_model_path = output_folder + example + ".onnx"
56
57     torch.onnx.export(
58         module._model_, module._dummy_, onnx_model_path, opset_version=opset_version)
59     print("Generate '" + example + ".onnx' - Done")
60
61     onnx_model = onnx.load(onnx_model_path)
62     onnx.checker.check_model(onnx_model)
63
64     inferred_model = onnx.shape_inference.infer_shapes(onnx_model)
65     onnx.checker.check_model(inferred_model)
66     onnx.save(inferred_model, onnx_model_path)
67
68     tf_prep = onnx_tf.backend.prepare(inferred_model)
69     tf_prep.export_graph(path=output_folder + example + ".TF")
70     print("Generate '" + example + " TF' - Done")
71
72     # for testing...
73     converter = tf.lite.TFLiteConverter.from_saved_model(output_folder + example + ".TF")
74     converter.allow_custom_ops = True
75     converter.experimental_new_converter = True
76     converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
77
78     tflite_model = converter.convert()
79     open(output_folder + example + ".tflite", "wb").write(tflite_model)
80     print("Generate '" + example + ".tflite' - Done")