Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / tools / kernel_report / kernel_report.py
1 #!/usr/bin/python
2
3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 import os
18 import argparse
19
20
21 class Backend:
22     def __init__(self, backendList):
23         self.backends = {}
24
25         for backend in backendList:
26             self.backends[backend] = False
27
28
29 class KernelReporter(object):
30     def __init__(self, args):
31         # TODO: Remove os defendency - '/'
32         if args.base[0] != '/':
33             self.onertBase = os.getcwd() + '/' + args.base
34         else:
35             self.onertBase = args.base
36         if args.md5:
37             self.printMD5 = True
38         else:
39             self.printMD5 = False
40         self.backendList = args.backends.split(',')
41         self.opListFile = "core/include/ir/Operations.lst"
42         self.operations = []
43         self.kernelGeneratorFile = "KernelGenerator.h"
44         self.kernelMap = {}
45
46     def parseOpList(self):
47         # Parsing line and get op list
48         skipLine = False
49         for line in open(self.onertBase + '/' + self.opListFile, "r"):
50             # Skip license
51             # TODO : Change to skip general comment
52             if skipLine:
53                 if line.startswith(" */"):
54                     skipLine = False
55                     continue
56                 continue
57             if line.startswith("/*"):
58                 skipLine = True
59                 continue
60
61             # Skip comment
62             if line.startswith("//"):
63                 continue
64
65             # Skip macro
66             if line.startswith("#"):
67                 continue
68
69             lineStripped = line.strip()
70             if len(lineStripped) == 0:
71                 continue
72
73             op = lineStripped[3:-1]
74             self.operations.append(op)
75             self.operations.sort()
76
77     def generateKernelMap(self):
78         for op in self.operations:
79             self.kernelMap[op] = Backend(self.backendList)
80
81         for backend in self.backendList:
82             buf = open(
83                 self.onertBase + '/backend/' + backend + '/' + self.kernelGeneratorFile,
84                 "r")
85
86             for line in buf:
87                 words = line.split()
88                 if len(words) < 3:
89                     continue
90                 if words[1] != "visit(const":
91                     continue
92
93                 opName = words[2].split("::")
94                 if len(opName) < 3:
95                     continue
96
97                 if opName[2] in self.operations:
98                     self.kernelMap[opName[2]].backends[backend] = True
99
100             buf.close()
101
102     def printResult(self):
103         print()
104         line = ""
105         for backend in self.backendList:
106             line = line + "{0:^9}".format(backend)
107         print('{0:30}{1}'.format("", line))
108
109         counts = []
110         for i in range(0, len(self.backendList), 1):
111             counts.append(0)
112
113         for op in self.operations:
114             line = ""
115             for i in range(0, len(self.backendList), 1):
116                 support = self.kernelMap[op].backends[self.backendList[i]]
117                 if support:
118                     line = line + "{0:^9}".format("O")
119                     counts[i] += 1
120                 else:
121                     line = line + "{0:^9}".format("-")
122             print('{0:30}{1}'.format(op, line))
123
124         line = ""
125         for count in counts:
126             line = line + "{0:^9}".format(count)
127         print('{0:30}{1}'.format("TOTAL COUNT", line))
128
129     def printMDFormat(self):
130         print()
131         line = "-"
132         for backend in self.backendList:
133             line = line + "|" + backend
134         print(line)
135         line = ""
136         for i in range(0, len(self.backendList), 1):
137             line = line + "-|"
138         print(line + "-")
139
140         counts = []
141         for i in range(0, len(self.backendList), 1):
142             counts.append(0)
143
144         for op in self.operations:
145             line = ""
146             for i in range(0, len(self.backendList), 1):
147                 support = self.kernelMap[op].backends[self.backendList[i]]
148                 if support:
149                     line = line + "|" + "O"
150                     counts[i] += 1
151                 else:
152                     line = line + "|" + "-"
153             print(op + line)
154
155         line = ""
156         for i in range(0, len(self.backendList), 1):
157             line = line + "-|"
158         print(line + "-")
159
160         line = ""
161         for count in counts:
162             line = line + "|" + str(count)
163
164         print("TOTAL COUNT" + line)
165
166     def run(self):
167         self.parseOpList()
168         self.generateKernelMap()
169
170         if self.printMD5:
171             self.printMDFormat()
172         else:
173             self.printResult()
174
175
176 if __name__ == '__main__':
177     arg_parser = argparse.ArgumentParser()
178     arg_parser.add_argument(
179         "--backends",
180         type=str,
181         default='cpu,acl_cl,acl_neon',
182         help="backend list to report (use comma)")
183     arg_parser.add_argument("--md5", action='store_true', help="Print for md5")
184     arg_parser.add_argument("base", type=str, help="onert base directory")
185     args = arg_parser.parse_args()
186
187     report = KernelReporter(args)
188     report.run()