# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# PyTorch Example manager
22 import tensorflow as tf
26 from pathlib import Path
# Print the installed version of each framework this script drives, listed
# in the order the conversion pipeline uses them.
for _banner, _framework in (
        ("PyTorch version=", torch),
        ("ONNX version=", onnx),
        ("ONNX-TF version=", onnx_tf),
        ("TF version=", tf),
):
    print(_banner, _framework.__version__)
# Drop the loop temporaries so the module namespace is unchanged.
del _banner, _framework
# Command line: one or more example names. Each NAME is later imported as the
# Python module `examples.NAME`.
parser = argparse.ArgumentParser(description='Process PyTorch python examples')

parser.add_argument(
    'examples',
    metavar='EXAMPLES',
    nargs='+',
    help="example name(s); each is imported as the module 'examples.<NAME>'")

args = parser.parse_args()

# Every generated artifact (.pth, .onnx, .TF, .tflite) is written under this
# folder; it is created (with any missing parents) if it does not exist yet.
output_folder = "./output/"

Path(output_folder).mkdir(parents=True, exist_ok=True)
# ONNX opset passed to torch.onnx.export when an example model does not
# request its own through an `onnx_opset_version()` method.
# NOTE(review): confirm 9 is the intended baseline opset for this toolchain.
DEFAULT_ONNX_OPSET_VERSION = 9

# Every artifact for every example is written into this directory.
output_dir = Path(output_folder)

# Convert each requested example through the chain
#   PyTorch (.pth) -> ONNX (.onnx) -> TensorFlow (.TF) -> TFLite (.tflite).
for example in args.examples:
    # Example modules are expected to expose `_model_` (the model object) and
    # `_dummy_` (the sample input handed to the ONNX exporter).
    module = importlib.import_module("examples." + example)
    model = module._model_

    # 1. Serialize the PyTorch model as-is.
    torch.save(model, str(output_dir / (example + ".pth")))
    print("Generate '" + example + ".pth' - Done")

    # 2. Export to ONNX. The opset is reset on every iteration so a model
    #    without `onnx_opset_version` never raises NameError or inherits the
    #    opset chosen for a previous example.
    opset_version = DEFAULT_ONNX_OPSET_VERSION
    if hasattr(model, 'onnx_opset_version'):
        opset_version = model.onnx_opset_version()

    onnx_model_path = str(output_dir / (example + ".onnx"))

    torch.onnx.export(
        model, module._dummy_, onnx_model_path, opset_version=opset_version)
    print("Generate '" + example + ".onnx' - Done")

    # 3. Validate the exported graph, run shape inference, validate again, and
    #    overwrite the .onnx file with the shape-annotated model.
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)

    inferred_model = onnx.shape_inference.infer_shapes(onnx_model)
    onnx.checker.check_model(inferred_model)
    onnx.save(inferred_model, onnx_model_path)

    # 4. Convert the ONNX model to a TensorFlow graph with onnx-tf.
    tf_model_path = str(output_dir / (example + ".TF"))
    tf_prep = onnx_tf.backend.prepare(inferred_model)
    tf_prep.export_graph(path=tf_model_path)
    print("Generate '" + example + " TF' - Done")

    # 5. Convert the exported TF model (loaded as a SavedModel) to TFLite,
    #    restricted to builtin ops while still permitting custom ops.
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
    converter.allow_custom_ops = True
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

    tflite_model = converter.convert()
    # write_bytes opens, writes and closes the file; the previous bare
    # open(...).write(...) left the handle to the garbage collector.
    (output_dir / (example + ".tflite")).write_bytes(tflite_model)
    print("Generate '" + example + ".tflite' - Done")