20a80c895427e09abea67bd813bab3a5c3137d88
[platform/core/ml/nnfw.git] / compiler / one-cmds / tests / pytorch-operations / example_generator.py
1 #!/usr/bin/env python3
2
3 # Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 # PyTorch Example manager
18
19 import torch
20 import importlib
21 import argparse
22 import os
23
24 from pathlib import Path
25
26 print("PyTorch version=", torch.__version__)
27
28 parser = argparse.ArgumentParser(description='Process PyTorch python examples')
29
30 parser.add_argument('examples', metavar='EXAMPLES', nargs='+')
31
32 args = parser.parse_args()
33
34 output_folder = "./"
35
36 Path(output_folder).mkdir(parents=True, exist_ok=True)
37
38
39 class JitWrapper(torch.nn.Module):
40     def __init__(self, model):
41         super().__init__()
42         self.model = model
43
44     def forward(self, *args):
45         if len(args) == 1:
46             return self.model.forward(args[0])
47         else:
48             return self.model.forward(args)
49
50
51 for example in args.examples:
52     print("Generate '" + example + ".pth'", end='')
53     # load example code
54     # replace - with _ in name, otherwise pytorch generates invalid torchscript
55     module_name = "examples." + example.replace('-', '_')
56     module_loader = importlib.machinery.SourceFileLoader(
57         module_name, os.path.join("examples", example, "__init__.py"))
58     module_spec = importlib.util.spec_from_loader(module_name, module_loader)
59     module = importlib.util.module_from_spec(module_spec)
60     module_loader.exec_module(module)
61
62     jittable_model = JitWrapper(module._model_)
63
64     traced_model = torch.jit.trace(jittable_model, module._dummy_)
65     # save .pth
66     torch.jit.save(traced_model, output_folder + example + ".pth")
67
68     input_shapes = ""
69     input_types = ""
70
71     input_samples = module._dummy_
72     if isinstance(input_samples, torch.Tensor):
73         input_samples = [input_samples]
74     for inp_idx in range(len(input_samples)):
75         input_data = input_samples[inp_idx]
76
77         shape = input_data.shape
78         for dim in range(len(shape)):
79             input_shapes += str(shape[dim])
80             if dim != len(shape) - 1:
81                 input_shapes += ","
82
83         if input_data.dtype == torch.bool:
84             input_types += "bool"
85         elif input_data.dtype == torch.uint8:
86             input_types += "uint8"
87         elif input_data.dtype == torch.int8:
88             input_types += "int8"
89         elif input_data.dtype == torch.int16:
90             input_types += "int16"
91         elif input_data.dtype == torch.int32:
92             input_types += "int32"
93         elif input_data.dtype == torch.int64:
94             input_types += "int16"
95         elif input_data.dtype == torch.float16:
96             input_types += "float32"
97         elif input_data.dtype == torch.float32:
98             input_types += "float32"
99         elif input_data.dtype == torch.float64:
100             input_types += "float64"
101         elif input_data.dtype == torch.complex64:
102             input_types += "complex64"
103         elif input_data.dtype == torch.complex128:
104             input_types += "complex128"
105         else:
106             raise ValueError('unsupported dtype')
107
108         if inp_idx != len(input_samples) - 1:
109             input_shapes += ":"
110             input_types += ","
111
112     with open(example + ".spec", "w") as spec_file:
113         print(input_shapes, file=spec_file)
114         print(input_types, file=spec_file)
115
116     print(" - Done")