Add CTS_ARB_gl_spirv test implementation
[platform/upstream/VK-GL-CTS.git] / scripts / testset.py
1 # -*- coding: utf-8 -*-
2
3 #-------------------------------------------------------------------------
4 # drawElements Quality Program utilities
5 # --------------------------------------
6 #
7 # Copyright 2015 The Android Open Source Project
8 #
9 # Licensed under the Apache License, Version 2.0 (the "License");
10 # you may not use this file except in compliance with the License.
11 # You may obtain a copy of the License at
12 #
13 #      http://www.apache.org/licenses/LICENSE-2.0
14 #
15 # Unless required by applicable law or agreed to in writing, software
16 # distributed under the License is distributed on an "AS IS" BASIS,
17 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 # See the License for the specific language governing permissions and
19 # limitations under the License.
20 #
21 #-------------------------------------------------------------------------
22
23 import sys
24 import random
25 import string
26 import subprocess
27 from optparse import OptionParser
28
29 def all (results, predicate):
30         for result in results:
31                 if not predicate(result):
32                         return False
33         return True
34
35 def any (results, predicate):
36         for result in results:
37                 if predicate(result):
38                         return True
39         return False
40
41 class FilterRule:
42         def __init__ (self, name, description, filters):
43                 self.name                       = name
44                 self.description        = description
45                 self.filters            = filters
46
47 class TestCaseResult:
48         def __init__ (self, name, results):
49                 self.name               = name
50                 self.results    = results
51
52 class Group:
53         def __init__ (self, name):
54                 self.name               = name
55                 self.cases              = []
56
57 def readCaseList (filename):
58         f = open(filename, 'rb')
59         cases = []
60         for line in f:
61                 if line[:6] == "TEST: ":
62                         case = line[6:].strip()
63                         if len(case) > 0:
64                                 cases.append(case)
65         return cases
66
67 def toResultList (caselist):
68         results = []
69         for case in caselist:
70                 results.append(TestCaseResult(case, []))
71         return results
72
73 def addResultsToCaseList (caselist, results):
74         resultMap       = {}
75         caseListRes     = toResultList(caselist)
76
77         for res in caseListRes:
78                 resultMap[res.name] = res
79
80         for result in results:
81                 if result.name in resultMap:
82                         resultMap[result.name].results += result.results
83
84         return caseListRes
85
86 def readTestResults (filename):
87         f                       = open(filename, 'rb')
88         csvData         = f.read()
89         csvLines        = csvData.splitlines()
90         results         = []
91
92         f.close()
93
94         for line in csvLines[1:]:
95                 args = line.split(',')
96                 if len(args) == 1:
97                         continue # Ignore
98
99                 results.append(TestCaseResult(args[0], args[1:]))
100
101         if len(results) == 0:
102                 raise Exception("Empty result list")
103
104         # Sanity check for results
105         numResultItems  = len(results[0].results)
106         seenResults             = set()
107         for result in results:
108                 if result.name in seenResults:
109                         raise Exception("Duplicate result row for test case '%s'" % result.name)
110                 if len(result.results) != numResultItems:
111                         raise Exception("Found %d results for test case '%s', expected %d" % (len(result.results), result.name, numResultItems))
112                 seenResults.add(result.name)
113
114         return results
115
116 def readGroupList (filename):
117         f = open(filename, 'rb')
118         groups = []
119         for line in f:
120                 group = line.strip()
121                 if group != "":
122                         groups.append(group)
123         return groups
124
125 def createGroups (results, groupNames):
126         groups  = []
127         matched = set()
128
129         for groupName in groupNames:
130                 group = Group(groupName)
131                 groups.append(group)
132
133                 prefix          = groupName + "."
134                 prefixLen       = len(prefix)
135                 for case in results:
136                         if case.name[:prefixLen] == prefix:
137                                 if case in matched:
138                                         die("Case '%s' matched by multiple groups (when processing '%s')" % (case.name, group.name))
139                                 group.cases.append(case)
140                                 matched.add(case)
141
142         return groups
143
144 def createLeafGroups (results):
145         groups = []
146         groupMap = {}
147
148         for case in results:
149                 parts           = case.name.split('.')
150                 groupName       = string.join(parts[:-1], ".")
151
152                 if not groupName in groupMap:
153                         group = Group(groupName)
154                         groups.append(group)
155                         groupMap[groupName] = group
156                 else:
157                         group = groupMap[groupName]
158
159                 group.cases.append(case)
160
161         return groups
162
163 def filterList (results, condition):
164         filtered = []
165         for case in results:
166                 if condition(case.results):
167                         filtered.append(case)
168         return filtered
169
170 def getFilter (list, name):
171         for filter in list:
172                 if filter.name == name:
173                         return filter
174         return None
175
176 def getNumCasesInGroups (groups):
177         numCases = 0
178         for group in groups:
179                 numCases += len(group.cases)
180         return numCases
181
182 def getCasesInSet (results, caseSet):
183         filtered = []
184         for case in results:
185                 if case in caseSet:
186                         filtered.append(case)
187         return filtered
188
189 def selectCasesInGroups (results, groups):
190         casesInGroups = set()
191         for group in groups:
192                 for case in group.cases:
193                         casesInGroups.add(case)
194         return getCasesInSet(results, casesInGroups)
195
196 def selectRandomSubset (results, groups, limit, seed):
197         selectedCases   = set()
198         numSelect               = min(limit, getNumCasesInGroups(groups))
199
200         random.seed(seed)
201         random.shuffle(groups)
202
203         groupNdx = 0
204         while len(selectedCases) < numSelect:
205                 group = groups[groupNdx]
206                 if len(group.cases) == 0:
207                         del groups[groupNdx]
208                         if groupNdx == len(groups):
209                                 groupNdx -= 1
210                         continue # Try next
211
212                 selected = random.choice(group.cases)
213                 selectedCases.add(selected)
214                 group.cases.remove(selected)
215
216                 groupNdx = (groupNdx + 1) % len(groups)
217
218         return getCasesInSet(results, selectedCases)
219
220 def die (msg):
221         print msg
222         sys.exit(-1)
223
224 # Named filter lists
225 FILTER_RULES = [
226         FilterRule("all",                       "No filtering",                                                                                 []),
227         FilterRule("all-pass",          "All results must be 'Pass'",                                                   [lambda l: all(l, lambda r: r == 'Pass')]),
228         FilterRule("any-pass",          "Any of results is 'Pass'",                                                             [lambda l: any(l, lambda r: r == 'Pass')]),
229         FilterRule("any-fail",          "Any of results is not 'Pass' or 'NotSupported'",               [lambda l: not all(l, lambda r: r == 'Pass' or r == 'NotSupported')]),
230         FilterRule("prev-failing",      "Any except last result is failure",                                    [lambda l: l[-1] == 'Pass' and not all(l[:-1], lambda r: r == 'Pass')]),
231         FilterRule("prev-passing",      "Any except last result is 'Pass'",                                             [lambda l: l[-1] != 'Pass' and any(l[:-1], lambda r: r == 'Pass')])
232 ]
233
234 if __name__ == "__main__":
235         parser = OptionParser(usage = "usage: %prog [options] [caselist] [result csv file]")
236         parser.add_option("-f", "--filter", dest="filter", default="all", help="filter rule name")
237         parser.add_option("-l", "--list", action="store_true", dest="list", default=False, help="list available rules")
238         parser.add_option("-n", "--num", dest="limit", default=0, help="limit number of cases")
239         parser.add_option("-s", "--seed", dest="seed", default=0, help="use selected seed for random selection")
240         parser.add_option("-g", "--groups", dest="groups_file", default=None, help="select cases based on group list file")
241
242         (options, args) = parser.parse_args()
243
244         if options.list:
245                 print "Available filter rules:"
246                 for filter in FILTER_RULES:
247                         print "  %s: %s" % (filter.name, filter.description)
248                 sys.exit(0)
249
250         if len(args) == 0:
251                 die("No input files specified")
252         elif len(args) > 2:
253                 die("Too many arguments")
254
255         # Fetch filter
256         filter = getFilter(FILTER_RULES, options.filter)
257         if filter == None:
258                 die("Unknown filter '%s'" % options.filter)
259
260         # Read case list
261         caselist = readCaseList(args[0])
262         if len(args) > 1:
263                 results = readTestResults(args[1])
264                 results = addResultsToCaseList(caselist, results)
265         else:
266                 results = toResultList(caselist)
267
268         # Execute filters for results
269         for rule in filter.filters:
270                 results = filterList(results, rule)
271
272         if options.limit != 0:
273                 if options.groups_file != None:
274                         groups = createGroups(results, readGroupList(options.groups_file))
275                 else:
276                         groups = createLeafGroups(results)
277                 results = selectRandomSubset(results, groups, int(options.limit), int(options.seed))
278         elif options.groups_file != None:
279                 groups = createGroups(results, readGroupList(options.groups_file))
280                 results = selectCasesInGroups(results, groups)
281
282         # Print test set
283         for result in results:
284                 print result.name