Imported Upstream version 1.0.0
[platform/upstream/js.git] / js / src / metrics / jint / treesearch.py
1 # vim: set ts=4 sw=4 tw=99 et: 
2
3 import os, re
4 import tempfile
5 import subprocess
6 import sys, math
7 import datetime
8 import random
9
def realpath(k):
    """Return the canonical path for ``k``: normalized, then symlink-resolved."""
    normalized = os.path.normpath(k)
    return os.path.realpath(normalized)
12
class UCTNode:
    """A node of the UCT search tree.

    ``loop`` is the index of the loop this node stands for (UCT.run() creates
    the root with -1). ``visits``/``score`` accumulate playout statistics, and
    ``children`` stays None until the node is expanded.
    """

    def __init__(self, loop):
        self.children = None
        self.loop = loop
        # Starting at 1 keeps computeUCB() from dividing by zero.
        self.visits = 1
        self.score = 0

    def addChild(self, child):
        """Append ``child``, creating the child list on first use."""
        if self.children is None:
            self.children = []
        self.children.append(child)

    def computeUCB(self, coeff):
        """Return the mean score plus the exploration term sqrt(coeff / visits)."""
        return (self.score / self.visits) + math.sqrt(coeff / self.visits)
27
class UCT:
    """Upper Confidence Tree search over per-loop settings.

    Each tree node below the root names one loop (an index into ``loops``).
    The loops on a root-to-node path are set to ``not enableLoops``; every
    other loop is chosen at random for a playout, which is timed through the
    benchmark object.
    """

    def __init__(self, benchmark, bestTime, enableLoops, loops, fd, playouts):
        # benchmark must provide generateBanList(loops, queue) and
        # treeSearchRun(fd, args, count); see playout().
        self.bm = benchmark
        self.fd = fd
        self.numPlayouts = playouts
        # Cap on the tree size; expandNode() stops adding children past it.
        self.maxNodes = self.numPlayouts * 20
        self.loops = loops
        self.enableLoops = enableLoops
        # A leaf must be visited this many times before it is expanded.
        self.maturityThreshold = 20
        # Baseline time used to normalize playout scores in step().
        self.originalBest = bestTime
        self.bestTime = bestTime
        # Exploration weight in the UCB coefficient.
        self.bias = 20
        # Root-to-leaf histories that reached the best (integer) time.
        self.combos = []
        # Playout results cached by the bitmask of the loop settings.
        self.zobrist = { }
        # Tree state; (re)initialized by run().
        self.numNodes = 0
        self.root = None
        random.seed()

    def expandNode(self, node, pending):
        """Add one child per loop in ``pending``.

        Returns False as soon as the tree reaches ``maxNodes`` (the node may
        then be only partially expanded), True otherwise.
        """
        for loop in pending:
            node.addChild(UCTNode(loop))
            self.numNodes += 1
            if self.numNodes >= self.maxNodes:
                return False
        return True

    def findBestChild(self, node):
        """Return the child of ``node`` with the highest UCB value.

        Ties go to the later child.
        """
        coeff = self.bias * math.log(node.visits)
        bestChild = None
        bestUCB = -float('Infinity')

        for child in node.children:
            ucb = child.computeUCB(coeff)
            if ucb >= bestUCB:
                bestUCB = ucb
                bestChild = child

        # Return the tracked winner, not the loop variable (the last child).
        return bestChild

    def playout(self, history):
        """Time one configuration consistent with the path ``history``.

        Loops named on the path are set to ``not self.enableLoops``; all
        others are drawn at random. Results are cached per configuration.
        """
        queue = [random.randint(0, 1) for _ in range(len(self.loops))]
        for node in history:
            # The root sentinel has loop == -1; indexing with it would
            # silently overwrite the last loop's setting.
            if node.loop >= 0:
                queue[node.loop] = not self.enableLoops
        zash = 0
        for i, setting in enumerate(queue):
            if setting:
                zash |= (1 << i)
        if zash in self.zobrist:
            return self.zobrist[zash]

        self.bm.generateBanList(self.loops, queue)
        result = self.bm.treeSearchRun(self.fd, ['-m', '-j'], 3)
        self.zobrist[zash] = result
        return result

    def step(self, loopList):
        """Run one UCT iteration: select, expand, play out, back-propagate.

        ``loopList`` is the full list of loop indices. A local copy shrinks
        as loops are chosen, so each path names each loop at most once.
        """
        node = self.root
        pending = loopList[:]
        history = [node]

        while True:
            if node.children is None:
                # A leaf visited often enough is expanded; keep descending
                # only if that actually produced children.
                if node.visits >= self.maturityThreshold:
                    if self.expandNode(node, pending) and node.children is not None:
                        continue

                # Otherwise run a playout from this leaf.
                score = self.playout(history)
                break

            node = self.findBestChild(node)
            history.append(node)
            pending.remove(node.loop)

        # Normalize to the relative improvement over the baseline time
        # (positive when faster than originalBest).
        origScore = score
        score = (self.originalBest - score) / self.originalBest

        for node in history:
            node.visits += 1
            node.score += score

        # Best configurations are tracked at integer granularity.
        if int(origScore) < int(self.bestTime):
            print('New best score: {0:f}ms'.format(origScore))
            self.combos = [history]
            self.bestTime = origScore
        elif int(origScore) == int(self.bestTime):
            self.combos.append(history)

    def run(self):
        """Perform ``numPlayouts`` iterations and return [bestTime, combos].

        Each combo has one int per loop: int(enableLoops) by default and
        int(not enableLoops) for loops on a winning path. If no playout
        matched the starting best time, combos is empty.
        """
        loopList = list(range(len(self.loops)))
        self.numNodes = 1
        self.root = UCTNode(-1)
        self.expandNode(self.root, loopList)

        for _ in range(self.numPlayouts):
            self.step(loopList)

        # Build the expected combination vectors.
        combos = []
        for combo in self.combos:
            vec = [int(self.enableLoops)] * len(self.loops)
            for node in combo:
                # Skip the root sentinel (loop == -1), which would otherwise
                # flip the last loop's entry.
                if node.loop >= 0:
                    vec[node.loop] = int(not self.enableLoops)
            combos.append(vec)

        return [self.bestTime, combos]
141
class Benchmark:
    """Runs a JS shell on a loop-annotated benchmark and searches loop settings.

    The benchmark source marks loops with ``/* BEGIN LOOP */`` and
    ``/* END LOOP */`` lines; see preprocess().
    """

    def __init__(self, JS, fname):
        self.fname = fname
        self.JS = JS
        # Per-run statistics and run names, read by winnerForLine() and
        # chart() (presumably {run name: {line: time}}, per the retired
        # parser below).
        self.stats = { }
        self.runList = [ ]

    def run(self, fd, eargs):
        """Run the JS shell with extra args ``eargs`` on fd's file; return stdout as text."""
        args = [self.JS]
        args.extend(eargs)
        args.append(fd.name)
        return subprocess.check_output(args).decode()

    # NOTE: retired output parser. It filled self.stats / self.runList from a
    # run's output; nothing in this class populates them any more.
    #    self.stats[name] = { }
    #    self.runList.append(name)
    #    for line in output.split('\n'):
    #        m = re.search('line (\d+): (\d+)', line)
    #        if m:
    #            self.stats[name][int(m.group(1))] = int(m.group(2))
    #        else:
    #            m = re.search('total: (\d+)', line)
    #            if m:
    #                self.stats[name]['total'] = m.group(1)

    def winnerForLine(self, line):
        """Return the name of the run with the smallest recorded value for ``line``."""
        best = self.runList[0]
        bestTime = self.stats[best][line]
        for run in self.runList[1:]:
            x = self.stats[run][line]
            if x < bestTime:
                best = run
                bestTime = x
        return best

    def chart(self):
        """Print a table of per-line values for each run plus the winning run.

        NOTE(review): this reads self.counters, which nothing in this class
        assigns (preprocess() returns counters instead), so it raises
        AttributeError unless a caller sets it first -- confirm before use.
        """
        sys.stdout.write('{0:7s}'.format(''))
        sys.stdout.write('{0:15s}'.format('line'))
        for run in self.runList:
            sys.stdout.write('{0:15s}'.format(run))
        sys.stdout.write('{0:15s}\n'.format('best'))
        for c in self.counters:
            sys.stdout.write('{0:10d}'.format(c))
            for run in self.runList:
                sys.stdout.write('{0:15d}'.format(self.stats[run][c]))
            sys.stdout.write('{0:12s}'.format(''))
            sys.stdout.write('{0:15s}'.format(self.winnerForLine(c)))
            sys.stdout.write('\n')

    def preprocess(self, lines, onBegin, onEnd):
        """Copy self.fname into ``lines``, replacing loop markers with hook output.

        onBegin(lines, lineno) is called at each BEGIN marker and
        onEnd(lines, startLineno, lineno) at the matching END marker, where
        the line numbers are len(lines) at that moment. Returns
        [lines, counters]. counters holds one [start, end] pair per loop:
        start is len(lines) when BEGIN was seen, end is len(lines) after
        onEnd ran.

        Raises IndexError on an END marker without a matching BEGIN.
        """
        stack = []
        counters = []
        with open(self.fname, 'rt') as rd:
            for line in rd:
                if re.search(r'/\* BEGIN LOOP \*/', line):
                    stack.append([len(lines), len(counters)])
                    counters.append([len(lines), 0])
                    onBegin(lines, len(lines))
                elif re.search(r'/\* END LOOP \*/', line):
                    old = stack.pop()
                    onEnd(lines, old[0], len(lines))
                    counters[old[1]][1] = len(lines)
                else:
                    lines.append(line)
        return [lines, counters]

    def treeSearchRun(self, fd, args, count = 5):
        """Run the shell ``count`` times and return the mean of the integer each run prints."""
        total = 0
        for _ in range(count):
            output = self.run(fd, args)
            total += int(output)
        return total / count

    def generateBanList(self, counters, queue):
        """Write /tmp/permabans with one "<line> <flag>" row per line of each loop.

        Every line in range(start, end + 1) of loop i gets flag int(queue[i]).
        """
        # Remove any previous file rather than writing through it.
        if os.path.exists('/tmp/permabans'):
            os.unlink('/tmp/permabans')
        with open('/tmp/permabans', 'wt') as fd:
            for i, (start, end) in enumerate(counters):
                for j in range(start, end + 1):
                    fd.write('{0:d} {1:d}\n'.format(j, int(queue[i])))

    def internalExhaustiveSearch(self, params):
        """Time every on/off combination of the loops; return [bestTime, bestCombos].

        Combination i gives loop j the value of bit j of i. A strictly faster
        time resets bestCombos; otherwise a time with the same integer part
        as the best is appended.
        """
        counters = params['counters']
        fd = params['fd']

        ncombos = 2 ** len(counters)
        queue = [0] * len(counters)

        bestTime = float('Infinity')
        bestCombos = []

        for i in range(ncombos):
            temp = i
            for j in range(len(counters)):
                queue[j] = temp & 1
                temp >>= 1
            self.generateBanList(counters, queue)

            t = self.treeSearchRun(fd, ['-m', '-j'])
            if t < bestTime:
                bestTime = t
                bestCombos = [queue[:]]
                print('New best time: {0:f}ms'.format(t))
            elif int(t) == int(bestTime):
                bestCombos.append(queue[:])

        return [bestTime, bestCombos]

    def internalTreeSearch(self, params):
        """Run a 50000-playout UCT search; return [bestTime, combos]."""
        fd = params['fd']
        combinedTime = params['combinedTime']
        counters = params['counters']

        # A timing-based default (enable loops when the method JIT alone beats
        # the combined run) used to be computed here, but it was always
        # overridden: the search starts with loops disabled.
        enableLoops = False

        uct = UCT(self, combinedTime, enableLoops, counters[:], fd, 50000)
        return uct.run()

    def treeSearch(self):
        """Time the baseline modes, search loop settings, and report the winners."""
        import shutil

        fd, counters = self.ppForTreeSearch()

        # Keep a copy of the instrumented script for manual inspection
        # (best effort, like the shell `cat` this replaces).
        try:
            shutil.copyfile(fd.name, '/tmp/k.js')
        except OSError as e:
            print('warning: could not save /tmp/k.js: {0}'.format(e))

        if os.path.exists('/tmp/permabans'):
            os.unlink('/tmp/permabans')
        methodTime = self.treeSearchRun(fd, ['-m'])
        tracerTime = self.treeSearchRun(fd, ['-j'])
        combinedTime = self.treeSearchRun(fd, ['-m', '-j'])

        # Rough estimate of an exhaustive search: every combination, 5 runs each.
        upperBound = max(methodTime, tracerTime, combinedTime)
        upperBound *= 2 ** len(counters)
        upperBound *= 5    # Number of runs
        useTreeSearch = False
        if upperBound < 1000:
            print('Estimating {0:d}ms to test, so picking exhaustive '.format(int(upperBound)) +
                  'search.')
        else:
            upperBound = int(upperBound / 1000)
            delta = datetime.timedelta(seconds = upperBound)
            if upperBound < 180:
                print('Estimating {0:d}s to test, so picking exhaustive search.'.format(int(upperBound)))
            else:
                print('Estimating {0:s} to test, so picking tree search.'.format(str(delta)))
                useTreeSearch = True

        params = {
                    'fd': fd,
                    'counters': counters,
                    'methodTime': methodTime,
                    'tracerTime': tracerTime,
                    'combinedTime': combinedTime
                 }

        print('Method JIT:  {0:d}ms'.format(int(methodTime)))
        print('Tracing JIT: {0:d}ms'.format(int(tracerTime)))
        print('Combined:    {0:d}ms'.format(int(combinedTime)))

        if useTreeSearch:
            results = self.internalTreeSearch(params)
        else:
            results = self.internalExhaustiveSearch(params)

        bestTime = results[0]
        bestCombos = results[1]
        print('Search found winning time {0:d}ms!'.format(int(bestTime)))
        print('Combos at this time: {0:d}'.format(len(bestCombos)))

        # The tree search only records combos that match or beat the combined
        # baseline, so there may be nothing to break down (and the percentage
        # below would divide by zero).
        if not bestCombos:
            return

        # Per loop, the share of winning combos with a 0 in its slot
        # (presumably "not banned", i.e. traced -- confirm against how the
        # shell reads /tmp/permabans).
        nCombos = len(bestCombos)
        for i, (start, end) in enumerate(counters):
            n = nCombos - sum(combo[i] for combo in bestCombos)
            print('\tloop @ {0:d}-{1:d} traced {2:d}% of the time'.format(
                    start, end, int(n / nCombos * 100)))

    def ppForTreeSearch(self):
        """Write the benchmark, with loop markers replaced, to a temp file that prints its run time.

        Returns [fd, counters]; fd is the open NamedTemporaryFile.
        """
        def onBegin(lines, lineno):
            lines.append('GLOBAL_THINGY = 1;\n')
        def onEnd(lines, old, lineno):
            lines.append('GLOBAL_THINGY = 1;\n')

        lines = ['var JINT_START_TIME = Date.now();\n',
                 'var GLOBAL_THINGY = 0;\n']

        lines, counters = self.preprocess(lines, onBegin, onEnd)
        fd = tempfile.NamedTemporaryFile('wt')
        for line in lines:
            fd.write(line)
        fd.write('print(Date.now() - JINT_START_TIME);\n')
        fd.flush()
        return [fd, counters]

    def preprocessForLoopCounting(self):
        """Write a copy of the benchmark that times each loop; return the open temp file.

        Each loop accumulates Date.now() deltas on a JINT_TRACKER property keyed
        by the loop's start index. The script prints "line <start>: <total>" per
        loop and then "total: <total>".
        """
        def onBegin(lines, lineno):
            lines.append('JINT_TRACKER.line_' + str(lineno) + '_start = Date.now();\n')

        def onEnd(lines, old, lineno):
            lines.append('JINT_TRACKER.line_' + str(old) + '_end = Date.now();\n')
            lines.append('JINT_TRACKER.line_' + str(old) + '_total += ' + \
                         'JINT_TRACKER.line_' + str(old) + '_end - ' + \
                         'JINT_TRACKER.line_' + str(old) + '_start;\n')

        lines, counters = self.preprocess([], onBegin, onEnd)
        # counters holds [start, end] pairs; the tracker keys use the start
        # index, which is the same value passed to onBegin/onEnd above.
        starts = [c[0] for c in counters]
        fd = tempfile.NamedTemporaryFile('wt')
        fd.write('var JINT_TRACKER = { };\n')
        for c in starts:
            fd.write('JINT_TRACKER.line_' + str(c) + '_start = 0;\n')
            fd.write('JINT_TRACKER.line_' + str(c) + '_end = 0;\n')
            fd.write('JINT_TRACKER.line_' + str(c) + '_total = 0;\n')
        fd.write('JINT_TRACKER.begin = Date.now();\n')
        for line in lines:
            fd.write(line)
        fd.write('JINT_TRACKER.total = Date.now() - JINT_TRACKER.begin;\n')
        for c in starts:
            fd.write('print("line ' + str(c) + ': " + JINT_TRACKER.line_' + str(c) +
                     '_total);\n')
        fd.write('print("total: " + JINT_TRACKER.total);\n')
        fd.flush()
        return fd
382
if __name__ == '__main__':
    # Usage: treesearch.py [options] JS_SHELL test
    #   JS_SHELL  path to the JS shell binary to benchmark
    #   test      path to the loop-annotated benchmark script
    from optparse import OptionParser
    op = OptionParser(usage='%prog [options] JS_SHELL test')
    (OPTIONS, args) = op.parse_args()
    if len(args) < 2:
        op.error('missing JS_SHELL and test argument')
    # Canonicalize both paths (normalize, then resolve symlinks).
    JS = realpath(args[0])
    test = realpath(args[1])

    bm = Benchmark(JS, test)
    bm.treeSearch()
408