3 # Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
17 # PyTorch Example manager
24 from pathlib import Path
# Log the installed PyTorch version.
26 print("PyTorch version=", torch.__version__)
# Command line: one or more example names (nargs='+'). Each name is used as a
# directory under "examples" whose __init__.py is loaded in the main loop.
28 parser = argparse.ArgumentParser(description='Process PyTorch python examples')
30 parser.add_argument('examples', metavar='EXAMPLES', nargs='+')
32 args = parser.parse_args()
# Create the output directory (and any missing parents); no error if it
# already exists.
# NOTE(review): output_folder is assigned on a line not visible in this
# excerpt — confirm its value where it is defined.
36 Path(output_folder).mkdir(parents=True, exist_ok=True)
# nn.Module adapter around an example's model so it can be handed to
# torch.jit.trace. The constructor body is not visible in this excerpt; since
# forward() reads self.model, it presumably stores `model` there — confirm.
39 class JitWrapper(torch.nn.Module):
40 def __init__(self, model):
# forward() takes variadic positional arguments because torch.jit.trace passes
# a tuple of example inputs as separate arguments. One branch forwards the
# first argument unchanged; the other forwards the whole args tuple as a
# single argument. Either way the wrapped model's forward receives one object.
# NOTE(review): the condition choosing between the two returns is elided in
# this excerpt; presumably it checks for exactly one argument — confirm.
44 def forward(self, *args):
46 return self.model.forward(args[0])
48 return self.model.forward(args)
# For each requested example: load its module, trace and save the model, then
# build the shape/type description of its sample inputs for the ".spec" file.
51 for example in args.examples:
# Progress message without a trailing newline; presumably completed by a later
# print not visible in this excerpt.
52 print("Generate '" + example + ".pth'", end='')
54 # replace - with _ in name, otherwise pytorch generates invalid torchscript
55 module_name = "examples." + example.replace('-', '_')
# Load examples/<example>/__init__.py straight from its file path and execute
# it. This runs arbitrary Python from the examples directory, so only point
# the tool at trusted examples.
56 module_loader = importlib.machinery.SourceFileLoader(
57 module_name, os.path.join("examples", example, "__init__.py"))
58 module_spec = importlib.util.spec_from_loader(module_name, module_loader)
59 module = importlib.util.module_from_spec(module_spec)
60 module_loader.exec_module(module)
# Each example module must define `_model_` (the model to export) and
# `_dummy_` (the sample input used for tracing).
62 jittable_model = JitWrapper(module._model_)
64 traced_model = torch.jit.trace(jittable_model, module._dummy_)
# NOTE(review): the path is built by plain string concatenation, so
# output_folder must end with a path separator — confirm where it is defined.
66 torch.jit.save(traced_model, output_folder + example + ".pth")
# Normalise the sample input to a sequence so single-tensor and multi-input
# examples share one code path. input_shapes / input_types are initialised on
# lines not visible here; they are strings (str += below), presumably empty.
71 input_samples = module._dummy_
72 if isinstance(input_samples, torch.Tensor):
73 input_samples = [input_samples]
74 for inp_idx in range(len(input_samples)):
75 input_data = input_samples[inp_idx]
# Append every dimension of this input's shape. A separator is added after
# each dimension except the last; the line that appends it is elided from
# this excerpt.
77 shape = input_data.shape
78 for dim in range(len(shape)):
79 input_shapes += str(shape[dim])
80 if dim != len(shape) - 1:
83 if input_data.dtype == torch.bool:
85 elif input_data.dtype == torch.uint8:
86 input_types += "uint8"
87 elif input_data.dtype == torch.int8:
89 elif input_data.dtype == torch.int16:
90 input_types += "int16"
91 elif input_data.dtype == torch.int32:
92 input_types += "int32"
93 elif input_data.dtype == torch.int64:
94 input_types += "int16"
95 elif input_data.dtype == torch.float16:
96 input_types += "float32"
97 elif input_data.dtype == torch.float32:
98 input_types += "float32"
99 elif input_data.dtype == torch.float64:
100 input_types += "float64"
101 elif input_data.dtype == torch.complex64:
102 input_types += "complex64"
103 elif input_data.dtype == torch.complex128:
104 input_types += "complex128"
106 raise ValueError('unsupported dtype')
# Between consecutive inputs (but not after the last one) a separator is
# appended. The lines doing so are elided from this excerpt; presumably they
# extend both input_shapes and input_types — confirm.
108 if inp_idx != len(input_samples) - 1:
# Write the spec file: first line is the accumulated shape string, second line
# the accumulated type string.
# NOTE(review): the spec is written relative to the current working directory,
# while the .pth goes to output_folder — confirm consumers expect this split.
112 with open(example + ".spec", "w") as spec_file:
113 print(input_shapes, file=spec_file)
114 print(input_types, file=spec_file)