v0.5.0
[platform/upstream/caffeonacl.git] / tools / extra / tpi.py
1 import sys
2 import os
3 import re
4 import pdb
5 import xlwt
6  
help_ = '''
Usage:
    python tpi.py log.txt
'''

# Accumulators filled by decodefile(): entries go to data_list1 until a
# "used time" line is seen in the log, and to data_list2 afterwards.
data_list1 = []
data_list2 = []

# Module-level defaults; the script entry point re-assigns these before each
# call to printresult(). `times` is the divisor printresult() applies to
# every accumulated value.
cnt = 0
times = 0.0
table_val = ''

# Names that printresult() places in fixed spreadsheet columns, in order.
name_list1 = [
    'allocate',
    'run',
    'configure',
    'tensor_copy',
    'ACL_CONV',
    'ACL_FC',
    'ACL_LRN',
    'ACL_POOLING',
    'ACL_RELU',
    'ACL_SOFTMAX',
]
19
20
21
22
def getvalpairs(words):
    """Return the first two non-empty items of *words* as ``(name, val)``.

    Items equal to ``''`` are ignored. A missing name or value is returned
    as the empty string.
    """
    tokens = [token for token in words if token != '']
    name = tokens[0] if len(tokens) > 0 else ''
    val = tokens[1] if len(tokens) > 1 else ''
    return (name, val)
37
def addpairstolist(db, name, val, idx):
    """Accumulate *val* under *name* in the list of records *db*.

    If a record with the same ``'name'`` already exists, *val* is added to
    its ``'val'`` and its ``'idx'`` is left untouched. Otherwise a new
    record ``{'idx', 'val', 'name'}`` is appended. *db* is modified in place
    and nothing is returned.
    """
    match = next((record for record in db if record['name'] == name), None)
    if match is None:
        db.append({'idx': idx, 'val': val, 'name': name})
    else:
        match['val'] += val
52
def gettabnum(line):
    """Return the 1-based position of the first non-empty tab-separated field.

    Only the text after the first ``':'`` is examined (the whole line when
    there is no colon). A single trailing newline is removed, leading spaces
    are stripped, and the remainder is split on tab characters. If every
    field is empty, the number of fields is returned.

    Unlike a blind ``line[start:-1]`` slice, a line that does not end in a
    newline (for example the last line of a file) keeps its final character.
    """
    colon = line.find(':')
    start = 0 if colon == -1 else colon + 1

    text = line[start:]
    if text.endswith('\n'):
        # Drop only a real line terminator, never payload characters.
        text = text[:-1]

    fields = text.lstrip(' ').split('\t')
    for position, field in enumerate(fields, 1):
        if field != '':
            return position
    return len(fields)
69
def decodefile(logfile):
    """Parse *logfile* and accumulate its numeric entries into the module lists.

    Lines without a ``':'`` are skipped. Each remaining line is split on
    tab, colon, space, CR and LF; the first two non-empty tokens are taken
    as ``(name, value)``. Entries are accumulated into ``data_list1`` until
    a line whose first two tokens are ``used`` and ``time`` is seen, and
    into ``data_list2`` from then on. Lines whose value token cannot be
    converted to ``float`` are ignored.
    """
    data_list = data_list1
    # Context manager so the log file is closed even if parsing fails.
    with open(logfile) as log:
        for line in log:
            if ':' not in line:
                continue
            idx = gettabnum(line)
            name, val = getvalpairs(re.split('\t|:| |\r|\n', line))
            if name == 'used' and val == 'time':
                # Marker line: everything after it belongs to the second list.
                data_list = data_list2
            try:
                number = float(val)
            except ValueError:
                # Not a "name value" measurement line (this includes the
                # marker line itself, whose value token is 'time').
                continue
            addpairstolist(data_list, name, number, idx)
89
def _print_table(tpi_text, entries, divisor):
    """Print a tab-separated header row and value row for *entries*.

    The first column is always ``TPI`` / *tpi_text*. A leading ``ACL_`` is
    removed from entry names in the header, and each value is divided by
    *divisor* and formatted with four decimals.
    """
    prefix = 'ACL_'
    heads = ['TPI']
    cells = [tpi_text]
    for entry in entries:
        label = entry['name']
        if label.startswith(prefix):
            label = label[len(prefix):]
        heads.append(label)
        cells.append('%.4f' % (entry['val'] / divisor))
    # Every column, including the last one, is terminated by a tab.
    print('\t'.join(heads) + '\t')
    print('\t'.join(cells) + '\t')


def printresult(db):
    """Print the accumulated records in *db* and write them to the sheet.

    *db* is a list of ``{'idx', 'val', 'name'}`` records; it is sorted in
    place by ``'idx'``. Every value is divided by the module global
    ``times`` before being shown. Two tables are printed to stdout: records
    whose ``idx`` is below that of ``ACL_CONV``, then the remaining ones.
    The same figures are written to the module global worksheet ``ws``
    starting at (``trow``, ``tcol``), using ``name_list1`` for the column
    layout. On return ``tcol`` is reset to 0 and ``trow`` advanced by 4.
    """
    global trow
    global tcol

    db.sort(key=lambda obj: obj.get('idx'))

    # idx of the ACL_CONV record; stays 0 when there is no such record, in
    # which case every record counts towards the TPI total.
    tpi_start = 0
    for entry in db:
        if entry['name'] == 'ACL_CONV':
            tpi_start = entry['idx']

    tpi = 0
    for entry in db:
        if entry['idx'] >= tpi_start:
            tpi += entry['val']
    tpi_text = '%.4f' % (tpi / times)

    _print_table(tpi_text, [e for e in db if e['idx'] < tpi_start], times)
    _print_table(tpi_text, [e for e in db if e['idx'] >= tpi_start], times)

    ws.write(trow, tcol, 'TPI')
    ws.write(trow + 1, tcol, tpi_text)
    tcol += 1

    # Lay out every known name with a '0' placeholder so that names missing
    # from db still appear. The first ACL_ name starts a new pair of rows at
    # column 2.
    temp_row = trow
    temp_col = tcol
    found_acl = False
    for name in name_list1:
        if name.startswith('ACL_') and not found_acl:
            temp_row += 2
            temp_col = 2
            found_acl = True
        ws.write(temp_row, temp_col, name)
        ws.write(temp_row + 1, temp_col, '0')
        temp_col += 1

    # Overwrite the placeholders with the measured values.
    for entry in db:
        curname = entry['name']
        curvalue = '%.4f' % (entry['val'] / times)
        if curname == 'ACL_BN':
            # NOTE(review): column 7 of this row is also where ACL_SOFTMAX
            # lands with the current name_list1, so whichever record comes
            # later in db overwrites the other - confirm this is intended.
            ws.write(trow + 2, 7, curname)
            ws.write(trow + 3, 7, curvalue)

        if curname in name_list1:
            val_col = name_list1.index(curname) + 2
            val_row = trow
            if val_col > 5:
                # Names past the first four go on the second pair of rows,
                # again starting at column 2.
                val_col -= 4
                val_row += 2
            ws.write(val_row, val_col, curname)
            ws.write(val_row + 1, val_col, curvalue)

    tcol = 0
    trow += 4
179
180
if __name__ == '__main__':
    # The log file path is the only argument; show usage when it is missing.
    if len(sys.argv) < 2:
        print(help_)
        sys.exit()

    logfile = sys.argv[1]
    filename = os.path.basename(logfile)
    decodefile(logfile)

    # ws, trow, tcol and times are module globals read by printresult().
    wb = xlwt.Workbook()
    ws = wb.add_sheet('testsheet', True)
    trow = 0
    tcol = 0

    # First section: values recorded before the "used time" marker,
    # reported as-is (divisor 1).
    cnt = 1
    times = 1.0
    table_val = ''
    print('1st time:')
    ws.write(trow, tcol, '1st time')
    tcol += 1
    printresult(data_list1)

    # Second section: values recorded after the marker, divided by 10.
    cnt = 2
    times = 10.0
    table_val = ''
    print('\nAverage of 2-11 times:')
    ws.write(trow, tcol, '2-11 times')
    tcol += 1
    printresult(data_list2)

    # The workbook is saved under the log's base name in the current directory.
    wb.save(filename + '.xls')
    print ('Xls file generated')
212
213