Merge pull request #276 from Ella-0/master
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / reconvergence / vktReconvergenceTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2018-2020 NVIDIA Corporation
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *        http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Vulkan Reconvergence tests
23  *//*--------------------------------------------------------------------*/
24
25 #include "vktReconvergenceTests.hpp"
26
27 #include "vkBufferWithMemory.hpp"
28 #include "vkImageWithMemory.hpp"
29 #include "vkQueryUtil.hpp"
30 #include "vkBuilderUtil.hpp"
31 #include "vkCmdUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkObjUtil.hpp"
34
35 #include "vktTestGroupUtil.hpp"
36 #include "vktTestCase.hpp"
37
38 #include "deDefs.h"
39 #include "deFloat16.h"
40 #include "deMath.h"
41 #include "deRandom.h"
42 #include "deSharedPtr.hpp"
43 #include "deString.h"
44
45 #include "tcuTestCase.hpp"
46 #include "tcuTestLog.hpp"
47
48 #include <bitset>
49 #include <string>
50 #include <sstream>
51 #include <set>
52 #include <vector>
53
54 namespace vkt
55 {
56 namespace Reconvergence
57 {
58 namespace
59 {
60 using namespace vk;
61 using namespace std;
62
// Number of elements in a statically sized array. Must not be used on pointers.
#define ARRAYSIZE(x) (sizeof(x) / sizeof((x)[0]))
64
65 const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;
66
// Which convergence guarantee a test case exercises, and which subgroup
// operation it uses to observe the set of active invocations.
enum TestType
{
        // subgroup_uniform_control_flow, observed with subgroupElect (subgroup_basic)
        TT_SUCF_ELECT,
        // subgroup_uniform_control_flow, observed with subgroupBallot (subgroup_ballot)
        TT_SUCF_BALLOT,
        // workgroup uniform control flow, observed with subgroupElect (subgroup_basic)
        TT_WUCF_ELECT,
        // workgroup uniform control flow, observed with subgroupBallot (subgroup_ballot)
        TT_WUCF_BALLOT,
        // maximal reconvergence
        TT_MAXIMAL,
};
74
75 struct CaseDef
76 {
77         TestType testType;
78         deUint32 maxNesting;
79         deUint32 seed;
80
81         bool isWUCF() const { return testType == TT_WUCF_ELECT || testType == TT_WUCF_BALLOT; }
82         bool isSUCF() const { return testType == TT_SUCF_ELECT || testType == TT_SUCF_BALLOT; }
83         bool isUCF() const { return isWUCF() || isSUCF(); }
84         bool isElect() const { return testType == TT_WUCF_ELECT || testType == TT_SUCF_ELECT; }
85 };
86
87 deUint64 subgroupSizeToMask(deUint32 subgroupSize)
88 {
89         if (subgroupSize == 64)
90                 return ~0ULL;
91         else
92                 return (1ULL << subgroupSize) - 1;
93 }
94
// 128-bit set: one bit per invocation, holding several subgroup masks side by side.
using bitset128 = std::bitset<128>;
96
97 // Take a 64-bit integer, mask it to the subgroup size, and then
98 // replicate it for each subgroup
99 bitset128 bitsetFromU64(deUint64 mask, deUint32 subgroupSize)
100 {
101         mask &= subgroupSizeToMask(subgroupSize);
102         bitset128 result(mask);
103         for (deUint32 i = 0; i < 128 / subgroupSize - 1; ++i)
104         {
105                 result = (result << subgroupSize) | bitset128(mask);
106         }
107         return result;
108 }
109
110 // Pick out the mask for the subgroup that invocationID is a member of
111 deUint64 bitsetToU64(const bitset128 &bitset, deUint32 subgroupSize, deUint32 invocationID)
112 {
113         bitset128 copy(bitset);
114         copy >>= (invocationID / subgroupSize) * subgroupSize;
115         copy &= bitset128(subgroupSizeToMask(subgroupSize));
116         deUint64 mask = copy.to_ullong();
117         mask &= subgroupSizeToMask(subgroupSize);
118         return mask;
119 }
120
121 class ReconvergenceTestInstance : public TestInstance
122 {
123 public:
124                                                 ReconvergenceTestInstance       (Context& context, const CaseDef& data);
125                                                 ~ReconvergenceTestInstance      (void);
126         tcu::TestStatus         iterate                         (void);
127 private:
128         CaseDef                 m_data;
129 };
130
131 ReconvergenceTestInstance::ReconvergenceTestInstance (Context& context, const CaseDef& data)
132         : vkt::TestInstance             (context)
133         , m_data                                (data)
134 {
135 }
136
137 ReconvergenceTestInstance::~ReconvergenceTestInstance (void)
138 {
139 }
140
141 class ReconvergenceTestCase : public TestCase
142 {
143         public:
144                                                                 ReconvergenceTestCase           (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
145                                                                 ~ReconvergenceTestCase  (void);
146         virtual void                            initPrograms            (SourceCollections& programCollection) const;
147         virtual TestInstance*           createInstance          (Context& context) const;
148         virtual void                            checkSupport            (Context& context) const;
149
150 private:
151         CaseDef                                 m_data;
152 };
153
154 ReconvergenceTestCase::ReconvergenceTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
155         : vkt::TestCase (context, name, desc)
156         , m_data                (data)
157 {
158 }
159
160 ReconvergenceTestCase::~ReconvergenceTestCase   (void)
161 {
162 }
163
164 void ReconvergenceTestCase::checkSupport(Context& context) const
165 {
166         if (!context.contextSupports(vk::ApiVersion(1, 1, 0)))
167                 TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
168
169         vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
170         deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
171         subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
172
173         vk::VkPhysicalDeviceProperties2 properties2;
174         deMemset(&properties2, 0, sizeof(properties2));
175         properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
176         properties2.pNext = &subgroupProperties;
177
178         context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties2);
179
180         if (m_data.isElect() && !(subgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BASIC_BIT))
181                 TCU_THROW(NotSupportedError, "VK_SUBGROUP_FEATURE_BASIC_BIT not supported");
182
183         if (!m_data.isElect() && !(subgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BALLOT_BIT))
184                 TCU_THROW(NotSupportedError, "VK_SUBGROUP_FEATURE_BALLOT_BIT not supported");
185
186         if (!(context.getSubgroupProperties().supportedStages & VK_SHADER_STAGE_COMPUTE_BIT))
187                 TCU_THROW(NotSupportedError, "compute stage does not support subgroup operations");
188
189         // Both subgroup- AND workgroup-uniform tests are enabled by shaderSubgroupUniformControlFlow.
190         if (m_data.isUCF() && !context.getShaderSubgroupUniformControlFlowFeatures().shaderSubgroupUniformControlFlow)
191                 TCU_THROW(NotSupportedError, "shaderSubgroupUniformControlFlow not supported");
192
193         // XXX TODO: Check for maximal reconvergence support
194         // if (m_data.testType == TT_MAXIMAL ...)
195 }
196
// Operations the random program is built from. The enumerator order is
// unchanged, so the numeric values are the same as before.
enum OPType
{
        // Store the result of subgroupBallot().
        // For OP_BALLOT, OP::caseValue starts at zero and is set to 1 by simulate
        // if the ballot is not workgroup- (or subgroup-) uniform. Only
        // workgroup-uniform ballots are validated for correctness in WUCF modes.
        OP_BALLOT,

        // Store a literal constant.
        OP_STORE,

        // if ((1ULL << gl_SubgroupInvocationID) & mask).
        // Special case: mask == ~0ULL is emitted as "if (inputA.a[idx] == idx)".
        OP_IF_MASK,
        OP_ELSE_MASK,
        OP_ENDIF,

        // if (gl_SubgroupInvocationID == loopIdxN), N being the most nested loop counter.
        OP_IF_LOOPCOUNT,
        OP_ELSE_LOOPCOUNT,

        // if (gl_LocalInvocationIndex >= inputA.a[N]), N being the most nested loop counter.
        OP_IF_LOCAL_INVOCATION_INDEX,
        OP_ELSE_LOCAL_INVOCATION_INDEX,

        // break / continue
        OP_BREAK,
        OP_CONTINUE,

        // if (subgroupElect())
        OP_ELECT,

        // Loop with a uniform number of iterations (read from a buffer).
        OP_BEGIN_FOR_UNIF,
        OP_END_FOR_UNIF,

        // for (int loopIdxN = 0; loopIdxN < gl_SubgroupInvocationID + 1; ++loopIdxN)
        OP_BEGIN_FOR_VAR,
        OP_END_FOR_VAR,

        // for (int loopIdxN = 0;; ++loopIdxN, OP_BALLOT)
        // Always contains an "if (subgroupElect()) break;".
        // Performs the equivalent of OP_BALLOT in the continue construct.
        OP_BEGIN_FOR_INF,
        OP_END_FOR_INF,

        // do { loopIdxN++; ... } while (loopIdxN < uniformValue);
        OP_BEGIN_DO_WHILE_UNIF,
        OP_END_DO_WHILE_UNIF,

        // do { ... } while (true);
        // Always contains an "if (subgroupElect()) break;".
        OP_BEGIN_DO_WHILE_INF,
        OP_END_DO_WHILE_INF,

        // return;
        OP_RETURN,

        // Function call: the code bracketed by these two is extracted into a separate function.
        OP_CALL_BEGIN,
        OP_CALL_END,

        // switch on a uniform value
        OP_SWITCH_UNIF_BEGIN,
        // switch on the value of gl_SubgroupInvocationID & 3
        OP_SWITCH_VAR_BEGIN,
        // switch on the value of loopIdx
        OP_SWITCH_LOOP_COUNT_BEGIN,

        // case label carrying an (invocation mask, case mask) pair
        OP_CASE_MASK_BEGIN,
        // case label for loop counter switches, carrying a value and a mask of loop iterations
        OP_CASE_LOOP_COUNT_BEGIN,

        // end of a switch / case statement
        OP_SWITCH_END,
        OP_CASE_END,

        // Extra code with no functional effect. Currently includes:
        // - value 0: while (!subgroupElect()) {}
        // - value 1: if (condition_that_is_false) { infinite loop }
        OP_NOISE,
};
281
// The kinds of condition an if statement can test.
enum IFType
{
        IF_MASK,
        IF_UNIFORM,
        IF_LOOPCOUNT,
        IF_LOCAL_INVOCATION_INDEX,
};
290
291 class OP
292 {
293 public:
294         OP(OPType _type, deUint64 _value, deUint32 _caseValue = 0)
295                 : type(_type), value(_value), caseValue(_caseValue)
296         {}
297
298         // The type of operation and an optional value.
299         // The value could be a mask for an if test, the index of the loop
300         // header for an end of loop, or the constant value for a store instruction
301         OPType type;
302         deUint64 value;
303         deUint32 caseValue;
304 };
305
306 static int findLSB (deUint64 value)
307 {
308         for (int i = 0; i < 64; i++)
309         {
310                 if (value & (1ULL<<i))
311                         return i;
312         }
313         return -1;
314 }
315
316 // For each subgroup, pick out the elected invocationID, and accumulate
317 // a bitset of all of them
318 static bitset128 bitsetElect (const bitset128& value, deInt32 subgroupSize)
319 {
320         bitset128 ret; // zero initialized
321
322         for (deInt32 i = 0; i < 128; i += subgroupSize)
323         {
324                 deUint64 mask = bitsetToU64(value, subgroupSize, i);
325                 int lsb = findLSB(mask);
326                 ret |= bitset128(lsb == -1 ? 0 : (1ULL << lsb)) << i;
327         }
328         return ret;
329 }
330
331 class RandomProgram
332 {
333 public:
334         RandomProgram(const CaseDef &c)
335                 : caseDef(c), numMasks(5), nesting(0), maxNesting(c.maxNesting), loopNesting(0), loopNestingThisFunction(0), callNesting(0), minCount(30), indent(0), isLoopInf(100, false), doneInfLoopBreak(100, false), storeBase(0x10000)
336         {
337                 deRandom_init(&rnd, caseDef.seed);
338                 for (int i = 0; i < numMasks; ++i)
339                         masks.push_back(deRandom_getUint64(&rnd));
340         }
341
342         const CaseDef caseDef;
343         deRandom rnd;
344         vector<OP> ops;
345         vector<deUint64> masks;
346         deInt32 numMasks;
347         deInt32 nesting;
348         deInt32 maxNesting;
349         deInt32 loopNesting;
350         deInt32 loopNestingThisFunction;
351         deInt32 callNesting;
352         deInt32 minCount;
353         deInt32 indent;
354         vector<bool> isLoopInf;
355         vector<bool> doneInfLoopBreak;
356         // Offset the value we use for OP_STORE, to avoid colliding with fully converged
357         // active masks with small subgroup sizes (e.g. with subgroupSize == 4, the SUCF
358         // tests need to know that 0xF is really an active mask).
359         deInt32 storeBase;
360
361         void genIf(IFType ifType)
362         {
363                 deUint32 maskIdx = deRandom_getUint32(&rnd) % numMasks;
364                 deUint64 mask = masks[maskIdx];
365                 if (ifType == IF_UNIFORM)
366                         mask = ~0ULL;
367
368                 deUint32 localIndexCmp = deRandom_getUint32(&rnd) % 128;
369                 if (ifType == IF_LOCAL_INVOCATION_INDEX)
370                         ops.push_back({OP_IF_LOCAL_INVOCATION_INDEX, localIndexCmp});
371                 else if (ifType == IF_LOOPCOUNT)
372                         ops.push_back({OP_IF_LOOPCOUNT, 0});
373                 else
374                         ops.push_back({OP_IF_MASK, mask});
375
376                 nesting++;
377
378                 size_t thenBegin = ops.size();
379                 pickOP(2);
380                 size_t thenEnd = ops.size();
381
382                 deUint32 randElse = (deRandom_getUint32(&rnd) % 100);
383                 if (randElse < 50)
384                 {
385                         if (ifType == IF_LOCAL_INVOCATION_INDEX)
386                                 ops.push_back({OP_ELSE_LOCAL_INVOCATION_INDEX, localIndexCmp});
387                         else if (ifType == IF_LOOPCOUNT)
388                                 ops.push_back({OP_ELSE_LOOPCOUNT, 0});
389                         else
390                                 ops.push_back({OP_ELSE_MASK, 0});
391
392                         if (randElse < 10)
393                         {
394                                 // Sometimes make the else block identical to the then block
395                                 for (size_t i = thenBegin; i < thenEnd; ++i)
396                                         ops.push_back(ops[i]);
397                         }
398                         else
399                                 pickOP(2);
400                 }
401                 ops.push_back({OP_ENDIF, 0});
402                 nesting--;
403         }
404
405         void genForUnif()
406         {
407                 deUint32 iterCount = (deRandom_getUint32(&rnd) % 5) + 1;
408                 ops.push_back({OP_BEGIN_FOR_UNIF, iterCount});
409                 deUint32 loopheader = (deUint32)ops.size()-1;
410                 nesting++;
411                 loopNesting++;
412                 loopNestingThisFunction++;
413                 pickOP(2);
414                 ops.push_back({OP_END_FOR_UNIF, loopheader});
415                 loopNestingThisFunction--;
416                 loopNesting--;
417                 nesting--;
418         }
419
420         void genDoWhileUnif()
421         {
422                 deUint32 iterCount = (deRandom_getUint32(&rnd) % 5) + 1;
423                 ops.push_back({OP_BEGIN_DO_WHILE_UNIF, iterCount});
424                 deUint32 loopheader = (deUint32)ops.size()-1;
425                 nesting++;
426                 loopNesting++;
427                 loopNestingThisFunction++;
428                 pickOP(2);
429                 ops.push_back({OP_END_DO_WHILE_UNIF, loopheader});
430                 loopNestingThisFunction--;
431                 loopNesting--;
432                 nesting--;
433         }
434
435         void genForVar()
436         {
437                 ops.push_back({OP_BEGIN_FOR_VAR, 0});
438                 deUint32 loopheader = (deUint32)ops.size()-1;
439                 nesting++;
440                 loopNesting++;
441                 loopNestingThisFunction++;
442                 pickOP(2);
443                 ops.push_back({OP_END_FOR_VAR, loopheader});
444                 loopNestingThisFunction--;
445                 loopNesting--;
446                 nesting--;
447         }
448
449         void genForInf()
450         {
451                 ops.push_back({OP_BEGIN_FOR_INF, 0});
452                 deUint32 loopheader = (deUint32)ops.size()-1;
453
454                 nesting++;
455                 loopNesting++;
456                 loopNestingThisFunction++;
457                 isLoopInf[loopNesting] = true;
458                 doneInfLoopBreak[loopNesting] = false;
459
460                 pickOP(2);
461
462                 genElect(true);
463                 doneInfLoopBreak[loopNesting] = true;
464
465                 pickOP(2);
466
467                 ops.push_back({OP_END_FOR_INF, loopheader});
468
469                 isLoopInf[loopNesting] = false;
470                 doneInfLoopBreak[loopNesting] = false;
471                 loopNestingThisFunction--;
472                 loopNesting--;
473                 nesting--;
474         }
475
476         void genDoWhileInf()
477         {
478                 ops.push_back({OP_BEGIN_DO_WHILE_INF, 0});
479                 deUint32 loopheader = (deUint32)ops.size()-1;
480
481                 nesting++;
482                 loopNesting++;
483                 loopNestingThisFunction++;
484                 isLoopInf[loopNesting] = true;
485                 doneInfLoopBreak[loopNesting] = false;
486
487                 pickOP(2);
488
489                 genElect(true);
490                 doneInfLoopBreak[loopNesting] = true;
491
492                 pickOP(2);
493
494                 ops.push_back({OP_END_DO_WHILE_INF, loopheader});
495
496                 isLoopInf[loopNesting] = false;
497                 doneInfLoopBreak[loopNesting] = false;
498                 loopNestingThisFunction--;
499                 loopNesting--;
500                 nesting--;
501         }
502
503         void genBreak()
504         {
505                 if (loopNestingThisFunction > 0)
506                 {
507                         // Sometimes put the break in a divergent if
508                         if ((deRandom_getUint32(&rnd) % 100) < 10)
509                         {
510                                 ops.push_back({OP_IF_MASK, masks[0]});
511                                 ops.push_back({OP_BREAK, 0});
512                                 ops.push_back({OP_ELSE_MASK, 0});
513                                 ops.push_back({OP_BREAK, 0});
514                                 ops.push_back({OP_ENDIF, 0});
515                         }
516                         else
517                                 ops.push_back({OP_BREAK, 0});
518                 }
519         }
520
521         void genContinue()
522         {
523                 // continues are allowed if we're in a loop and the loop is not infinite,
524                 // or if it is infinite and we've already done a subgroupElect+break.
525                 // However, adding more continues seems to reduce the failure rate, so
526                 // disable it for now
527                 if (loopNestingThisFunction > 0 && !(isLoopInf[loopNesting] /*&& !doneInfLoopBreak[loopNesting]*/))
528                 {
529                         // Sometimes put the continue in a divergent if
530                         if ((deRandom_getUint32(&rnd) % 100) < 10)
531                         {
532                                 ops.push_back({OP_IF_MASK, masks[0]});
533                                 ops.push_back({OP_CONTINUE, 0});
534                                 ops.push_back({OP_ELSE_MASK, 0});
535                                 ops.push_back({OP_CONTINUE, 0});
536                                 ops.push_back({OP_ENDIF, 0});
537                         }
538                         else
539                                 ops.push_back({OP_CONTINUE, 0});
540                 }
541         }
542
543         // doBreak is used to generate "if (subgroupElect()) { ... break; }" inside infinite loops
544         void genElect(bool doBreak)
545         {
546                 ops.push_back({OP_ELECT, 0});
547                 nesting++;
548                 if (doBreak)
549                 {
550                         // Put something interestign before the break
551                         optBallot();
552                         optBallot();
553                         if ((deRandom_getUint32(&rnd) % 100) < 10)
554                                 pickOP(1);
555
556                         // if we're in a function, sometimes  use return instead
557                         if (callNesting > 0 && (deRandom_getUint32(&rnd) % 100) < 30)
558                                 ops.push_back({OP_RETURN, 0});
559                         else
560                                 genBreak();
561
562                 }
563                 else
564                         pickOP(2);
565
566                 ops.push_back({OP_ENDIF, 0});
567                 nesting--;
568         }
569
570         void genReturn()
571         {
572                 deUint32 r = deRandom_getUint32(&rnd) % 100;
573                 if (nesting > 0 &&
574                         // Use return rarely in main, 20% of the time in a singly nested loop in a function
575                         // and 50% of the time in a multiply nested loop in a function
576                         (r < 5 ||
577                          (callNesting > 0 && loopNestingThisFunction > 0 && r < 20) ||
578                          (callNesting > 0 && loopNestingThisFunction > 1 && r < 50)))
579                 {
580                         optBallot();
581                         if ((deRandom_getUint32(&rnd) % 100) < 10)
582                         {
583                                 ops.push_back({OP_IF_MASK, masks[0]});
584                                 ops.push_back({OP_RETURN, 0});
585                                 ops.push_back({OP_ELSE_MASK, 0});
586                                 ops.push_back({OP_RETURN, 0});
587                                 ops.push_back({OP_ENDIF, 0});
588                         }
589                         else
590                                 ops.push_back({OP_RETURN, 0});
591                 }
592         }
593
594         // Generate a function call. Save and restore some loop information, which is used to
595         // determine when it's safe to use break/continue
596         void genCall()
597         {
598                 ops.push_back({OP_CALL_BEGIN, 0});
599                 callNesting++;
600                 nesting++;
601                 deInt32 saveLoopNestingThisFunction = loopNestingThisFunction;
602                 loopNestingThisFunction = 0;
603
604                 pickOP(2);
605
606                 loopNestingThisFunction = saveLoopNestingThisFunction;
607                 nesting--;
608                 callNesting--;
609                 ops.push_back({OP_CALL_END, 0});
610         }
611
612         // Generate switch on a uniform value:
613         // switch (inputA.a[r]) {
614         // case r+1: ... break; // should not execute
615         // case r:   ... break; // should branch uniformly
616         // case r+2: ... break; // should not execute
617         // }
618         void genSwitchUnif()
619         {
620                 deUint32 r = deRandom_getUint32(&rnd) % 5;
621                 ops.push_back({OP_SWITCH_UNIF_BEGIN, r});
622                 nesting++;
623
624                 ops.push_back({OP_CASE_MASK_BEGIN, 0, 1u<<(r+1)});
625                 pickOP(1);
626                 ops.push_back({OP_CASE_END, 0});
627
628                 ops.push_back({OP_CASE_MASK_BEGIN, ~0ULL, 1u<<r});
629                 pickOP(2);
630                 ops.push_back({OP_CASE_END, 0});
631
632                 ops.push_back({OP_CASE_MASK_BEGIN, 0, 1u<<(r+2)});
633                 pickOP(1);
634                 ops.push_back({OP_CASE_END, 0});
635
636                 ops.push_back({OP_SWITCH_END, 0});
637                 nesting--;
638         }
639
640         // switch (gl_SubgroupInvocationID & 3) with four unique targets
641         void genSwitchVar()
642         {
643                 ops.push_back({OP_SWITCH_VAR_BEGIN, 0});
644                 nesting++;
645
646                 ops.push_back({OP_CASE_MASK_BEGIN, 0x1111111111111111ULL, 1<<0});
647                 pickOP(1);
648                 ops.push_back({OP_CASE_END, 0});
649
650                 ops.push_back({OP_CASE_MASK_BEGIN, 0x2222222222222222ULL, 1<<1});
651                 pickOP(1);
652                 ops.push_back({OP_CASE_END, 0});
653
654                 ops.push_back({OP_CASE_MASK_BEGIN, 0x4444444444444444ULL, 1<<2});
655                 pickOP(1);
656                 ops.push_back({OP_CASE_END, 0});
657
658                 ops.push_back({OP_CASE_MASK_BEGIN, 0x8888888888888888ULL, 1<<3});
659                 pickOP(1);
660                 ops.push_back({OP_CASE_END, 0});
661
662                 ops.push_back({OP_SWITCH_END, 0});
663                 nesting--;
664         }
665
666         // switch (gl_SubgroupInvocationID & 3) with two shared targets.
667         // XXX TODO: The test considers these two targets to remain converged,
668         // though we haven't agreed to that behavior yet.
669         void genSwitchMulticase()
670         {
671                 ops.push_back({OP_SWITCH_VAR_BEGIN, 0});
672                 nesting++;
673
674                 ops.push_back({OP_CASE_MASK_BEGIN, 0x3333333333333333ULL, (1<<0)|(1<<1)});
675                 pickOP(2);
676                 ops.push_back({OP_CASE_END, 0});
677
678                 ops.push_back({OP_CASE_MASK_BEGIN, 0xCCCCCCCCCCCCCCCCULL, (1<<2)|(1<<3)});
679                 pickOP(2);
680                 ops.push_back({OP_CASE_END, 0});
681
682                 ops.push_back({OP_SWITCH_END, 0});
683                 nesting--;
684         }
685
686         // switch (loopIdxN) {
687         // case 1:  ... break;
688         // case 2:  ... break;
689         // default: ... break;
690         // }
691         void genSwitchLoopCount()
692         {
693                 deUint32 r = deRandom_getUint32(&rnd) % loopNesting;
694                 ops.push_back({OP_SWITCH_LOOP_COUNT_BEGIN, r});
695                 nesting++;
696
697                 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, 1ULL<<1, 1});
698                 pickOP(1);
699                 ops.push_back({OP_CASE_END, 0});
700
701                 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, 1ULL<<2, 2});
702                 pickOP(1);
703                 ops.push_back({OP_CASE_END, 0});
704
705                 // default:
706                 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, ~6ULL, 0xFFFFFFFF});
707                 pickOP(1);
708                 ops.push_back({OP_CASE_END, 0});
709
710                 ops.push_back({OP_SWITCH_END, 0});
711                 nesting--;
712         }
713
714         void pickOP(deUint32 count)
715         {
716                 // Pick "count" instructions. These can recursively insert more instructions,
717                 // so "count" is just a seed
718                 for (deUint32 i = 0; i < count; ++i)
719                 {
720                         optBallot();
721                         if (nesting < maxNesting)
722                         {
723                                 deUint32 r = deRandom_getUint32(&rnd) % 11;
724                                 switch (r)
725                                 {
726                                 default:
727                                         DE_ASSERT(0);
728                                         // fallthrough
729                                 case 2:
730                                         if (loopNesting)
731                                         {
732                                                 genIf(IF_LOOPCOUNT);
733                                                 break;
734                                         }
735                                         // fallthrough
736                                 case 10:
737                                         genIf(IF_LOCAL_INVOCATION_INDEX);
738                                         break;
739                                 case 0:
740                                         genIf(IF_MASK);
741                                         break;
742                                 case 1:
743                                         genIf(IF_UNIFORM);
744                                         break;
745                                 case 3:
746                                         {
747                                                 // don't nest loops too deeply, to avoid extreme memory usage or timeouts
748                                                 if (loopNesting <= 3)
749                                                 {
750                                                         deUint32 r2 = deRandom_getUint32(&rnd) % 3;
751                                                         switch (r2)
752                                                         {
753                                                         default: DE_ASSERT(0); // fallthrough
754                                                         case 0: genForUnif(); break;
755                                                         case 1: genForInf(); break;
756                                                         case 2: genForVar(); break;
757                                                         }
758                                                 }
759                                         }
760                                         break;
761                                 case 4:
762                                         genBreak();
763                                         break;
764                                 case 5:
765                                         genContinue();
766                                         break;
767                                 case 6:
768                                         genElect(false);
769                                         break;
770                                 case 7:
771                                         {
772                                                 deUint32 r2 = deRandom_getUint32(&rnd) % 5;
773                                                 if (r2 == 0 && callNesting == 0 && nesting < maxNesting - 2)
774                                                         genCall();
775                                                 else
776                                                         genReturn();
777                                                 break;
778                                         }
779                                 case 8:
780                                         {
781                                                 // don't nest loops too deeply, to avoid extreme memory usage or timeouts
782                                                 if (loopNesting <= 3)
783                                                 {
784                                                         deUint32 r2 = deRandom_getUint32(&rnd) % 2;
785                                                         switch (r2)
786                                                         {
787                                                         default: DE_ASSERT(0); // fallthrough
788                                                         case 0: genDoWhileUnif(); break;
789                                                         case 1: genDoWhileInf(); break;
790                                                         }
791                                                 }
792                                         }
793                                         break;
794                                 case 9:
795                                         {
796                                                 deUint32 r2 = deRandom_getUint32(&rnd) % 4;
797                                                 switch (r2)
798                                                 {
799                                                 default:
800                                                         DE_ASSERT(0);
801                                                         // fallthrough
802                                                 case 0:
803                                                         genSwitchUnif();
804                                                         break;
805                                                 case 1:
806                                                         if (loopNesting > 0) {
807                                                                 genSwitchLoopCount();
808                                                                 break;
809                                                         }
810                                                         // fallthrough
811                                                 case 2:
812                                                         if (caseDef.testType != TT_MAXIMAL)
813                                                         {
814                                                                 // multicase doesn't have fully-defined behavior for MAXIMAL tests,
815                                                                 // but does for SUCF tests
816                                                                 genSwitchMulticase();
817                                                                 break;
818                                                         }
819                                                         // fallthrough
820                                                 case 3:
821                                                         genSwitchVar();
822                                                         break;
823                                                 }
824                                         }
825                                         break;
826                                 }
827                         }
828                         optBallot();
829                 }
830         }
831
832         void optBallot()
833         {
834                 // optionally insert ballots, stores, and noise. Ballots and stores are used to determine
835                 // correctness.
836                 if ((deRandom_getUint32(&rnd) % 100) < 20)
837                 {
838                         if (ops.size() < 2 ||
839                            !(ops[ops.size()-1].type == OP_BALLOT ||
840                                  (ops[ops.size()-1].type == OP_STORE && ops[ops.size()-2].type == OP_BALLOT)))
841                         {
842                                 // do a store along with each ballot, so we can correlate where
843                                 // the ballot came from
844                                 if (caseDef.testType != TT_MAXIMAL)
845                                         ops.push_back({OP_STORE, (deUint32)ops.size() + storeBase});
846                                 ops.push_back({OP_BALLOT, 0});
847                         }
848                 }
849
850                 if ((deRandom_getUint32(&rnd) % 100) < 10)
851                 {
852                         if (ops.size() < 2 ||
853                            !(ops[ops.size()-1].type == OP_STORE ||
854                                  (ops[ops.size()-1].type == OP_BALLOT && ops[ops.size()-2].type == OP_STORE)))
855                         {
856                                 // SUCF does a store with every ballot. Don't bloat the code by adding more.
857                                 if (caseDef.testType == TT_MAXIMAL)
858                                         ops.push_back({OP_STORE, (deUint32)ops.size() + storeBase});
859                         }
860                 }
861
862                 deUint32 r = deRandom_getUint32(&rnd) % 10000;
863                 if (r < 3)
864                         ops.push_back({OP_NOISE, 0});
865                 else if (r < 10)
866                         ops.push_back({OP_NOISE, 1});
867         }
868
869         void generateRandomProgram()
870         {
871                 do {
872                         ops.clear();
873                         while ((deInt32)ops.size() < minCount)
874                                 pickOP(1);
875
876                         // Retry until the program has some UCF results in it
877                         if (caseDef.isUCF())
878                         {
879                                 const deUint32 invocationStride = 128;
880                                 // Simulate for all subgroup sizes, to determine whether OP_BALLOTs are nonuniform
881                                 for (deInt32 subgroupSize = 4; subgroupSize <= 64; subgroupSize *= 2) {
882                                         simulate(true, subgroupSize, invocationStride, DE_NULL);
883                                 }
884                         }
885                 } while (caseDef.isUCF() && !hasUCF());
886         }
887
888         void printIndent(std::stringstream &css)
889         {
890                 for (deInt32 i = 0; i < indent; ++i)
891                         css << " ";
892         }
893
// GLSL expression used as the ballot value by printBallot() when inside loops in MAXIMAL tests.
std::string genPartitionBallot()
{
    return std::string("subgroupBallot(true).xy");
}
900
901         void printBallot(std::stringstream *css)
902         {
903                 *css << "outputC.loc[gl_LocalInvocationIndex]++,";
904                 // When inside loop(s), use partitionBallot rather than subgroupBallot to compute
905                 // a ballot, to make sure the ballot is "diverged enough". Don't do this for
906                 // subgroup_uniform_control_flow, since we only validate results that must be fully
907                 // reconverged.
908                 if (loopNesting > 0 && caseDef.testType == TT_MAXIMAL)
909                 {
910                         *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex] = " << genPartitionBallot();
911                 }
912                 else if (caseDef.isElect())
913                 {
914                         *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex].x = elect()";
915                 }
916                 else
917                 {
918                         *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex] = subgroupBallot(true).xy";
919                 }
920         }
921
922         void genCode(std::stringstream &functions, std::stringstream &main)
923         {
924                 std::stringstream *css = &main;
925                 indent = 4;
926                 loopNesting = 0;
927                 int funcNum = 0;
928                 for (deInt32 i = 0; i < (deInt32)ops.size(); ++i)
929                 {
930                         switch (ops[i].type)
931                         {
932                         case OP_IF_MASK:
933                                 printIndent(*css);
934                                 if (ops[i].value == ~0ULL)
935                                 {
936                                         // This equality test will always succeed, since inputA.a[i] == i
937                                         int idx = deRandom_getUint32(&rnd) % 4;
938                                         *css << "if (inputA.a[" << idx << "] == " << idx << ") {\n";
939                                 }
940                                 else
941                                         *css << "if (testBit(uvec2(0x" << std::hex << (ops[i].value & 0xFFFFFFFF) << ", 0x" << (ops[i].value >> 32) << "), gl_SubgroupInvocationID)) {\n";
942
943                                 indent += 4;
944                                 break;
945                         case OP_IF_LOOPCOUNT:
946                                 printIndent(*css); *css << "if (gl_SubgroupInvocationID == loopIdx" << loopNesting - 1 << ") {\n";
947                                 indent += 4;
948                                 break;
949                         case OP_IF_LOCAL_INVOCATION_INDEX:
950                                 printIndent(*css); *css << "if (gl_LocalInvocationIndex >= inputA.a[0x" << std::hex << ops[i].value << "]) {\n";
951                                 indent += 4;
952                                 break;
953                         case OP_ELSE_MASK:
954                         case OP_ELSE_LOOPCOUNT:
955                         case OP_ELSE_LOCAL_INVOCATION_INDEX:
956                                 indent -= 4;
957                                 printIndent(*css); *css << "} else {\n";
958                                 indent += 4;
959                                 break;
960                         case OP_ENDIF:
961                                 indent -= 4;
962                                 printIndent(*css); *css << "}\n";
963                                 break;
964                         case OP_BALLOT:
965                                 printIndent(*css); printBallot(css); *css << ";\n";
966                                 break;
967                         case OP_STORE:
968                                 printIndent(*css); *css << "outputC.loc[gl_LocalInvocationIndex]++;\n";
969                                 printIndent(*css); *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex].x = 0x" << std::hex << ops[i].value << ";\n";
970                                 break;
971                         case OP_BEGIN_FOR_UNIF:
972                                 printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;\n";
973                                 printIndent(*css); *css << "         loopIdx" << loopNesting << " < inputA.a[" << ops[i].value << "];\n";
974                                 printIndent(*css); *css << "         loopIdx" << loopNesting << "++) {\n";
975                                 indent += 4;
976                                 loopNesting++;
977                                 break;
978                         case OP_END_FOR_UNIF:
979                                 loopNesting--;
980                                 indent -= 4;
981                                 printIndent(*css); *css << "}\n";
982                                 break;
983                         case OP_BEGIN_DO_WHILE_UNIF:
984                                 printIndent(*css); *css << "{\n";
985                                 indent += 4;
986                                 printIndent(*css); *css << "int loopIdx" << loopNesting << " = 0;\n";
987                                 printIndent(*css); *css << "do {\n";
988                                 indent += 4;
989                                 printIndent(*css); *css << "loopIdx" << loopNesting << "++;\n";
990                                 loopNesting++;
991                                 break;
992                         case OP_BEGIN_DO_WHILE_INF:
993                                 printIndent(*css); *css << "{\n";
994                                 indent += 4;
995                                 printIndent(*css); *css << "int loopIdx" << loopNesting << " = 0;\n";
996                                 printIndent(*css); *css << "do {\n";
997                                 indent += 4;
998                                 loopNesting++;
999                                 break;
1000                         case OP_END_DO_WHILE_UNIF:
1001                                 loopNesting--;
1002                                 indent -= 4;
1003                                 printIndent(*css); *css << "} while (loopIdx" << loopNesting << " < inputA.a[" << ops[(deUint32)ops[i].value].value << "]);\n";
1004                                 indent -= 4;
1005                                 printIndent(*css); *css << "}\n";
1006                                 break;
1007                         case OP_END_DO_WHILE_INF:
1008                                 loopNesting--;
1009                                 printIndent(*css); *css << "loopIdx" << loopNesting << "++;\n";
1010                                 indent -= 4;
1011                                 printIndent(*css); *css << "} while (true);\n";
1012                                 indent -= 4;
1013                                 printIndent(*css); *css << "}\n";
1014                                 break;
1015                         case OP_BEGIN_FOR_VAR:
1016                                 printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;\n";
1017                                 printIndent(*css); *css << "         loopIdx" << loopNesting << " < gl_SubgroupInvocationID + 1;\n";
1018                                 printIndent(*css); *css << "         loopIdx" << loopNesting << "++) {\n";
1019                                 indent += 4;
1020                                 loopNesting++;
1021                                 break;
1022                         case OP_END_FOR_VAR:
1023                                 loopNesting--;
1024                                 indent -= 4;
1025                                 printIndent(*css); *css << "}\n";
1026                                 break;
1027                         case OP_BEGIN_FOR_INF:
1028                                 printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;;loopIdx" << loopNesting << "++,";
1029                                 loopNesting++;
1030                                 printBallot(css);
1031                                 *css << ") {\n";
1032                                 indent += 4;
1033                                 break;
1034                         case OP_END_FOR_INF:
1035                                 loopNesting--;
1036                                 indent -= 4;
1037                                 printIndent(*css); *css << "}\n";
1038                                 break;
1039                         case OP_BREAK:
1040                                 printIndent(*css); *css << "break;\n";
1041                                 break;
1042                         case OP_CONTINUE:
1043                                 printIndent(*css); *css << "continue;\n";
1044                                 break;
1045                         case OP_ELECT:
1046                                 printIndent(*css); *css << "if (subgroupElect()) {\n";
1047                                 indent += 4;
1048                                 break;
1049                         case OP_RETURN:
1050                                 printIndent(*css); *css << "return;\n";
1051                                 break;
1052                         case OP_CALL_BEGIN:
1053                                 printIndent(*css); *css << "func" << funcNum << "(";
1054                                 for (deInt32 n = 0; n < loopNesting; ++n)
1055                                 {
1056                                         *css << "loopIdx" << n;
1057                                         if (n != loopNesting - 1)
1058                                                 *css << ", ";
1059                                 }
1060                                 *css << ");\n";
1061                                 css = &functions;
1062                                 printIndent(*css); *css << "void func" << funcNum << "(";
1063                                 for (deInt32 n = 0; n < loopNesting; ++n)
1064                                 {
1065                                         *css << "int loopIdx" << n;
1066                                         if (n != loopNesting - 1)
1067                                                 *css << ", ";
1068                                 }
1069                                 *css << ") {\n";
1070                                 indent += 4;
1071                                 funcNum++;
1072                                 break;
1073                         case OP_CALL_END:
1074                                 indent -= 4;
1075                                 printIndent(*css); *css << "}\n";
1076                                 css = &main;
1077                                 break;
1078                         case OP_NOISE:
1079                                 if (ops[i].value == 0)
1080                                 {
1081                                         printIndent(*css); *css << "while (!subgroupElect()) {}\n";
1082                                 }
1083                                 else
1084                                 {
1085                                         printIndent(*css); *css << "if (inputA.a[0] == 12345) {\n";
1086                                         indent += 4;
1087                                         printIndent(*css); *css << "while (true) {\n";
1088                                         indent += 4;
1089                                         printIndent(*css); printBallot(css); *css << ";\n";
1090                                         indent -= 4;
1091                                         printIndent(*css); *css << "}\n";
1092                                         indent -= 4;
1093                                         printIndent(*css); *css << "}\n";
1094                                 }
1095                                 break;
1096                         case OP_SWITCH_UNIF_BEGIN:
1097                                 printIndent(*css); *css << "switch (inputA.a[" << ops[i].value << "]) {\n";
1098                                 indent += 4;
1099                                 break;
1100                         case OP_SWITCH_VAR_BEGIN:
1101                                 printIndent(*css); *css << "switch (gl_SubgroupInvocationID & 3) {\n";
1102                                 indent += 4;
1103                                 break;
1104                         case OP_SWITCH_LOOP_COUNT_BEGIN:
1105                                 printIndent(*css); *css << "switch (loopIdx" << ops[i].value << ") {\n";
1106                                 indent += 4;
1107                                 break;
1108                         case OP_SWITCH_END:
1109                                 indent -= 4;
1110                                 printIndent(*css); *css << "}\n";
1111                                 break;
1112                         case OP_CASE_MASK_BEGIN:
1113                                 for (deInt32 b = 0; b < 32; ++b)
1114                                 {
1115                                         if ((1u<<b) & ops[i].caseValue)
1116                                         {
1117                                                 printIndent(*css); *css << "case " << b << ":\n";
1118                                         }
1119                                 }
1120                                 printIndent(*css); *css << "{\n";
1121                                 indent += 4;
1122                                 break;
1123                         case OP_CASE_LOOP_COUNT_BEGIN:
1124                                 if (ops[i].caseValue == 0xFFFFFFFF)
1125                                 {
1126                                         printIndent(*css); *css << "default: {\n";
1127                                 }
1128                                 else
1129                                 {
1130                                         printIndent(*css); *css << "case " << ops[i].caseValue << ": {\n";
1131                                 }
1132                                 indent += 4;
1133                                 break;
1134                         case OP_CASE_END:
1135                                 printIndent(*css); *css << "break;\n";
1136                                 indent -= 4;
1137                                 printIndent(*css); *css << "}\n";
1138                                 break;
1139                         default:
1140                                 DE_ASSERT(0);
1141                                 break;
1142                         }
1143                 }
1144         }
1145
1146         // Simulate execution of the program. If countOnly is true, just return
1147         // the max number of outputs written. If it's false, store out the result
1148         // values to ref
1149         deUint32 simulate(bool countOnly, deUint32 subgroupSize, deUint32 invocationStride, deUint64 *ref)
1150         {
1151                 // State of the subgroup at each level of nesting
1152                 struct SubgroupState
1153                 {
1154                         // Currently executing
1155                         bitset128 activeMask;
1156                         // Have executed a continue instruction in this loop
1157                         bitset128 continueMask;
1158                         // index of the current if test or loop header
1159                         deUint32 header;
1160                         // number of loop iterations performed
1161                         deUint32 tripCount;
1162                         // is this nesting a loop?
1163                         deUint32 isLoop;
1164                         // is this nesting a function call?
1165                         deUint32 isCall;
1166                         // is this nesting a switch?
1167                         deUint32 isSwitch;
1168                 };
1169                 SubgroupState stateStack[10];
1170                 deMemset(&stateStack, 0, sizeof(stateStack));
1171
1172                 const deUint64 fullSubgroupMask = subgroupSizeToMask(subgroupSize);
1173
1174                 // Per-invocation output location counters
1175                 deUint32 outLoc[128] = {0};
1176
1177                 nesting = 0;
1178                 loopNesting = 0;
1179                 stateStack[nesting].activeMask = ~bitset128(); // initialized to ~0
1180
1181                 deInt32 i = 0;
1182                 while (i < (deInt32)ops.size())
1183                 {
1184                         switch (ops[i].type)
1185                         {
1186                         case OP_BALLOT:
1187
1188                                 // Flag that this ballot is workgroup-nonuniform
1189                                 if (caseDef.isWUCF() && stateStack[nesting].activeMask.any() && !stateStack[nesting].activeMask.all())
1190                                         ops[i].caseValue = 1;
1191
1192                                 if (caseDef.isSUCF())
1193                                 {
1194                                         for (deUint32 id = 0; id < 128; id += subgroupSize)
1195                                         {
1196                                                 deUint64 subgroupMask = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
1197                                                 // Flag that this ballot is subgroup-nonuniform
1198                                                 if (subgroupMask != 0 && subgroupMask != fullSubgroupMask)
1199                                                         ops[i].caseValue = 1;
1200                                         }
1201                                 }
1202
1203                                 for (deUint32 id = 0; id < 128; ++id)
1204                                 {
1205                                         if (stateStack[nesting].activeMask.test(id))
1206                                         {
1207                                                 if (countOnly)
1208                                                 {
1209                                                         outLoc[id]++;
1210                                                 }
1211                                                 else
1212                                                 {
1213                                                         if (ops[i].caseValue)
1214                                                         {
1215                                                                 // Emit a magic value to indicate that we shouldn't validate this ballot
1216                                                                 ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(0x12345678, subgroupSize, id);
1217                                                         }
1218                                                         else
1219                                                                 ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
1220                                                 }
1221                                         }
1222                                 }
1223                                 break;
1224                         case OP_STORE:
1225                                 for (deUint32 id = 0; id < 128; ++id)
1226                                 {
1227                                         if (stateStack[nesting].activeMask.test(id))
1228                                         {
1229                                                 if (countOnly)
1230                                                         outLoc[id]++;
1231                                                 else
1232                                                         ref[(outLoc[id]++)*invocationStride + id] = ops[i].value;
1233                                         }
1234                                 }
1235                                 break;
1236                         case OP_IF_MASK:
1237                                 nesting++;
1238                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64(ops[i].value, subgroupSize);
1239                                 stateStack[nesting].header = i;
1240                                 stateStack[nesting].isLoop = 0;
1241                                 stateStack[nesting].isSwitch = 0;
1242                                 break;
1243                         case OP_ELSE_MASK:
1244                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & ~bitsetFromU64(ops[stateStack[nesting].header].value, subgroupSize);
1245                                 break;
1246                         case OP_IF_LOOPCOUNT:
1247                                 {
1248                                         deUint32 n = nesting;
1249                                         while (!stateStack[n].isLoop)
1250                                                 n--;
1251
1252                                         nesting++;
1253                                         stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64((1ULL << stateStack[n].tripCount), subgroupSize);
1254                                         stateStack[nesting].header = i;
1255                                         stateStack[nesting].isLoop = 0;
1256                                         stateStack[nesting].isSwitch = 0;
1257                                         break;
1258                                 }
1259                         case OP_ELSE_LOOPCOUNT:
1260                                 {
1261                                         deUint32 n = nesting;
1262                                         while (!stateStack[n].isLoop)
1263                                                 n--;
1264
1265                                         stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & ~bitsetFromU64((1ULL << stateStack[n].tripCount), subgroupSize);
1266                                         break;
1267                                 }
1268                         case OP_IF_LOCAL_INVOCATION_INDEX:
1269                                 {
1270                                         // all bits >= N
1271                                         bitset128 mask(0);
1272                                         for (deInt32 j = (deInt32)ops[i].value; j < 128; ++j)
1273                                                 mask.set(j);
1274
1275                                         nesting++;
1276                                         stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & mask;
1277                                         stateStack[nesting].header = i;
1278                                         stateStack[nesting].isLoop = 0;
1279                                         stateStack[nesting].isSwitch = 0;
1280                                         break;
1281                                 }
1282                         case OP_ELSE_LOCAL_INVOCATION_INDEX:
1283                                 {
1284                                         // all bits < N
1285                                         bitset128 mask(0);
1286                                         for (deInt32 j = 0; j < (deInt32)ops[i].value; ++j)
1287                                                 mask.set(j);
1288
1289                                         stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & mask;
1290                                         break;
1291                                 }
1292                         case OP_ENDIF:
1293                                 nesting--;
1294                                 break;
1295                         case OP_BEGIN_FOR_UNIF:
1296                                 // XXX TODO: We don't handle a for loop with zero iterations
1297                                 nesting++;
1298                                 loopNesting++;
1299                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1300                                 stateStack[nesting].header = i;
1301                                 stateStack[nesting].tripCount = 0;
1302                                 stateStack[nesting].isLoop = 1;
1303                                 stateStack[nesting].isSwitch = 0;
1304                                 stateStack[nesting].continueMask = 0;
1305                                 break;
1306                         case OP_END_FOR_UNIF:
1307                                 stateStack[nesting].tripCount++;
1308                                 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1309                                 stateStack[nesting].continueMask = 0;
1310                                 if (stateStack[nesting].tripCount < ops[stateStack[nesting].header].value &&
1311                                         stateStack[nesting].activeMask.any())
1312                                 {
1313                                         i = stateStack[nesting].header+1;
1314                                         continue;
1315                                 }
1316                                 else
1317                                 {
1318                                         loopNesting--;
1319                                         nesting--;
1320                                 }
1321                                 break;
1322                         case OP_BEGIN_DO_WHILE_UNIF:
1323                                 // XXX TODO: We don't handle a for loop with zero iterations
1324                                 nesting++;
1325                                 loopNesting++;
1326                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1327                                 stateStack[nesting].header = i;
1328                                 stateStack[nesting].tripCount = 1;
1329                                 stateStack[nesting].isLoop = 1;
1330                                 stateStack[nesting].isSwitch = 0;
1331                                 stateStack[nesting].continueMask = 0;
1332                                 break;
1333                         case OP_END_DO_WHILE_UNIF:
1334                                 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1335                                 stateStack[nesting].continueMask = 0;
1336                                 if (stateStack[nesting].tripCount < ops[stateStack[nesting].header].value &&
1337                                         stateStack[nesting].activeMask.any())
1338                                 {
1339                                         i = stateStack[nesting].header+1;
1340                                         stateStack[nesting].tripCount++;
1341                                         continue;
1342                                 }
1343                                 else
1344                                 {
1345                                         loopNesting--;
1346                                         nesting--;
1347                                 }
1348                                 break;
1349                         case OP_BEGIN_FOR_VAR:
1350                                 // XXX TODO: We don't handle a for loop with zero iterations
1351                                 nesting++;
1352                                 loopNesting++;
1353                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1354                                 stateStack[nesting].header = i;
1355                                 stateStack[nesting].tripCount = 0;
1356                                 stateStack[nesting].isLoop = 1;
1357                                 stateStack[nesting].isSwitch = 0;
1358                                 stateStack[nesting].continueMask = 0;
1359                                 break;
1360                         case OP_END_FOR_VAR:
1361                                 stateStack[nesting].tripCount++;
1362                                 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1363                                 stateStack[nesting].continueMask = 0;
1364                                 stateStack[nesting].activeMask &= bitsetFromU64(stateStack[nesting].tripCount == subgroupSize ? 0 : ~((1ULL << (stateStack[nesting].tripCount)) - 1), subgroupSize);
1365                                 if (stateStack[nesting].activeMask.any())
1366                                 {
1367                                         i = stateStack[nesting].header+1;
1368                                         continue;
1369                                 }
1370                                 else
1371                                 {
1372                                         loopNesting--;
1373                                         nesting--;
1374                                 }
1375                                 break;
1376                         case OP_BEGIN_FOR_INF:
1377                         case OP_BEGIN_DO_WHILE_INF:
1378                                 nesting++;
1379                                 loopNesting++;
1380                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1381                                 stateStack[nesting].header = i;
1382                                 stateStack[nesting].tripCount = 0;
1383                                 stateStack[nesting].isLoop = 1;
1384                                 stateStack[nesting].isSwitch = 0;
1385                                 stateStack[nesting].continueMask = 0;
1386                                 break;
1387                         case OP_END_FOR_INF:
1388                                 stateStack[nesting].tripCount++;
1389                                 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1390                                 stateStack[nesting].continueMask = 0;
1391                                 if (stateStack[nesting].activeMask.any())
1392                                 {
1393                                         // output expected OP_BALLOT values
1394                                         for (deUint32 id = 0; id < 128; ++id)
1395                                         {
1396                                                 if (stateStack[nesting].activeMask.test(id))
1397                                                 {
1398                                                         if (countOnly)
1399                                                                 outLoc[id]++;
1400                                                         else
1401                                                                 ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
1402                                                 }
1403                                         }
1404
1405                                         i = stateStack[nesting].header+1;
1406                                         continue;
1407                                 }
1408                                 else
1409                                 {
1410                                         loopNesting--;
1411                                         nesting--;
1412                                 }
1413                                 break;
1414                         case OP_END_DO_WHILE_INF:
1415                                 stateStack[nesting].tripCount++;
1416                                 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1417                                 stateStack[nesting].continueMask = 0;
1418                                 if (stateStack[nesting].activeMask.any())
1419                                 {
1420                                         i = stateStack[nesting].header+1;
1421                                         continue;
1422                                 }
1423                                 else
1424                                 {
1425                                         loopNesting--;
1426                                         nesting--;
1427                                 }
1428                                 break;
1429                         case OP_BREAK:
1430                                 {
1431                                         deUint32 n = nesting;
1432                                         bitset128 mask = stateStack[nesting].activeMask;
1433                                         while (true)
1434                                         {
1435                                                 stateStack[n].activeMask &= ~mask;
1436                                                 if (stateStack[n].isLoop || stateStack[n].isSwitch)
1437                                                         break;
1438
1439                                                 n--;
1440                                         }
1441                                 }
1442                                 break;
1443                         case OP_CONTINUE:
1444                                 {
1445                                         deUint32 n = nesting;
1446                                         bitset128 mask = stateStack[nesting].activeMask;
1447                                         while (true)
1448                                         {
1449                                                 stateStack[n].activeMask &= ~mask;
1450                                                 if (stateStack[n].isLoop)
1451                                                 {
1452                                                         stateStack[n].continueMask |= mask;
1453                                                         break;
1454                                                 }
1455                                                 n--;
1456                                         }
1457                                 }
1458                                 break;
1459                         case OP_ELECT:
1460                                 {
1461                                         nesting++;
1462                                         stateStack[nesting].activeMask = bitsetElect(stateStack[nesting-1].activeMask, subgroupSize);
1463                                         stateStack[nesting].header = i;
1464                                         stateStack[nesting].isLoop = 0;
1465                                         stateStack[nesting].isSwitch = 0;
1466                                 }
1467                                 break;
1468                         case OP_RETURN:
1469                                 {
1470                                         bitset128 mask = stateStack[nesting].activeMask;
1471                                         for (deInt32 n = nesting; n >= 0; --n)
1472                                         {
1473                                                 stateStack[n].activeMask &= ~mask;
1474                                                 if (stateStack[n].isCall)
1475                                                         break;
1476                                         }
1477                                 }
1478                                 break;
1479
1480                         case OP_CALL_BEGIN:
1481                                 nesting++;
1482                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1483                                 stateStack[nesting].isLoop = 0;
1484                                 stateStack[nesting].isSwitch = 0;
1485                                 stateStack[nesting].isCall = 1;
1486                                 break;
1487                         case OP_CALL_END:
1488                                 stateStack[nesting].isCall = 0;
1489                                 nesting--;
1490                                 break;
1491                         case OP_NOISE:
1492                                 break;
1493
1494                         case OP_SWITCH_UNIF_BEGIN:
1495                         case OP_SWITCH_VAR_BEGIN:
1496                         case OP_SWITCH_LOOP_COUNT_BEGIN:
1497                                 nesting++;
1498                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1499                                 stateStack[nesting].header = i;
1500                                 stateStack[nesting].isLoop = 0;
1501                                 stateStack[nesting].isSwitch = 1;
1502                                 break;
1503                         case OP_SWITCH_END:
1504                                 nesting--;
1505                                 break;
1506                         case OP_CASE_MASK_BEGIN:
1507                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64(ops[i].value, subgroupSize);
1508                                 break;
1509                         case OP_CASE_LOOP_COUNT_BEGIN:
1510                                 {
1511                                         deUint32 n = nesting;
1512                                         deUint32 l = loopNesting;
1513
1514                                         while (true)
1515                                         {
1516                                                 if (stateStack[n].isLoop)
1517                                                 {
1518                                                         l--;
1519                                                         if (l == ops[stateStack[nesting].header].value)
1520                                                                 break;
1521                                                 }
1522                                                 n--;
1523                                         }
1524
1525                                         if ((1ULL << stateStack[n].tripCount) & ops[i].value)
1526                                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1527                                         else
1528                                                 stateStack[nesting].activeMask = 0;
1529                                         break;
1530                                 }
1531                         case OP_CASE_END:
1532                                 break;
1533
1534                         default:
1535                                 DE_ASSERT(0);
1536                                 break;
1537                         }
1538                         i++;
1539                 }
1540                 deUint32 maxLoc = 0;
1541                 for (deUint32 id = 0; id < ARRAYSIZE(outLoc); ++id)
1542                         maxLoc = de::max(maxLoc, outLoc[id]);
1543
1544                 return maxLoc;
1545         }
1546
1547         bool hasUCF() const
1548         {
1549                 for (deInt32 i = 0; i < (deInt32)ops.size(); ++i)
1550                 {
1551                         if (ops[i].type == OP_BALLOT && ops[i].caseValue == 0)
1552                                 return true;
1553                 }
1554                 return false;
1555         }
1556 };
1557
1558 void ReconvergenceTestCase::initPrograms (SourceCollections& programCollection) const
1559 {
1560         RandomProgram program(m_data);
1561         program.generateRandomProgram();
1562
1563         std::stringstream css;
1564         css << "#version 450 core\n";
1565         css << "#extension GL_KHR_shader_subgroup_ballot : enable\n";
1566         css << "#extension GL_KHR_shader_subgroup_vote : enable\n";
1567         css << "#extension GL_NV_shader_subgroup_partitioned : enable\n";
1568         css << "#extension GL_EXT_subgroup_uniform_control_flow : enable\n";
1569         css << "layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n";
1570         css << "layout(set=0, binding=0) coherent buffer InputA { uint a[]; } inputA;\n";
1571         css << "layout(set=0, binding=1) coherent buffer OutputB { uvec2 b[]; } outputB;\n";
1572         css << "layout(set=0, binding=2) coherent buffer OutputC { uint loc[]; } outputC;\n";
1573         css << "layout(push_constant) uniform PC {\n"
1574                         "   // set to the real stride when writing out ballots, or zero when just counting\n"
1575                         "   int invocationStride;\n"
1576                         "};\n";
1577         css << "int outLoc = 0;\n";
1578
1579         css << "bool testBit(uvec2 mask, uint bit) { return (bit < 32) ? ((mask.x >> bit) & 1) != 0 : ((mask.y >> (bit-32)) & 1) != 0; }\n";
1580
1581         css << "uint elect() { return int(subgroupElect()) + 1; }\n";
1582
1583         std::stringstream functions, main;
1584         program.genCode(functions, main);
1585
1586         css << functions.str() << "\n\n";
1587
1588         css <<
1589                 "void main()\n"
1590                 << (m_data.isSUCF() ? "[[subgroup_uniform_control_flow]]\n" : "") <<
1591                 "{\n";
1592
1593         css << main.str() << "\n\n";
1594
1595         css << "}\n";
1596
1597         const vk::ShaderBuildOptions    buildOptions    (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
1598
1599         programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
1600 }
1601
1602 TestInstance* ReconvergenceTestCase::createInstance (Context& context) const
1603 {
1604         return new ReconvergenceTestInstance(context, m_data);
1605 }
1606
1607 tcu::TestStatus ReconvergenceTestInstance::iterate (void)
1608 {
1609         const DeviceInterface&  vk                                              = m_context.getDeviceInterface();
1610         const VkDevice                  device                                  = m_context.getDevice();
1611         Allocator&                              allocator                               = m_context.getDefaultAllocator();
1612         tcu::TestLog&                   log                                             = m_context.getTestContext().getLog();
1613
1614         deRandom rnd;
1615         deRandom_init(&rnd, m_data.seed);
1616
1617         vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
1618         deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
1619         subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
1620
1621         vk::VkPhysicalDeviceProperties2 properties2;
1622         deMemset(&properties2, 0, sizeof(properties2));
1623         properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
1624         properties2.pNext = &subgroupProperties;
1625
1626         m_context.getInstanceInterface().getPhysicalDeviceProperties2(m_context.getPhysicalDevice(), &properties2);
1627
1628         const deUint32 subgroupSize = subgroupProperties.subgroupSize;
1629         const deUint32 invocationStride = 128;
1630
1631         if (subgroupSize > 64)
1632                 TCU_THROW(TestError, "Subgroup size greater than 64 not handled.");
1633
1634         RandomProgram program(m_data);
1635         program.generateRandomProgram();
1636
1637         deUint32 maxLoc = program.simulate(true, subgroupSize, invocationStride, DE_NULL);
1638
1639         // maxLoc is per-invocation. Add one (to make sure no additional writes are done) and multiply by
1640         // the number of invocations
1641         maxLoc++;
1642         maxLoc *= invocationStride;
1643
1644         // buffer[0] is an input filled with a[i] == i
1645         // buffer[1] is the output
1646         // buffer[2] is the location counts
1647         de::MovePtr<BufferWithMemory> buffers[3];
1648         vk::VkDescriptorBufferInfo bufferDescriptors[3];
1649
1650         VkDeviceSize sizes[3] =
1651         {
1652                 128 * sizeof(deUint32),
1653                 maxLoc * sizeof(deUint64),
1654                 invocationStride * sizeof(deUint32),
1655         };
1656
1657         for (deUint32 i = 0; i < 3; ++i)
1658         {
1659                 try
1660                 {
1661                         buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1662                                 vk, device, allocator, makeBufferCreateInfo(sizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT),
1663                                 MemoryRequirement::HostVisible | MemoryRequirement::Cached));
1664                 }
1665                 catch(tcu::ResourceError&)
1666                 {
1667                         // Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
1668                         return tcu::TestStatus(QP_TEST_RESULT_QUALITY_WARNING, "Failed device memory allocation " + de::toString(sizes[i]) + " bytes");
1669                 }
1670                 bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, sizes[i]);
1671         }
1672
1673         deUint32 *ptrs[3];
1674         for (deUint32 i = 0; i < 3; ++i)
1675         {
1676                 ptrs[i] = (deUint32 *)buffers[i]->getAllocation().getHostPtr();
1677         }
1678         for (deUint32 i = 0; i < sizes[0] / sizeof(deUint32); ++i)
1679         {
1680                 ptrs[0][i] = i;
1681         }
1682         deMemset(ptrs[1], 0, (size_t)sizes[1]);
1683         deMemset(ptrs[2], 0, (size_t)sizes[2]);
1684
1685         vk::DescriptorSetLayoutBuilder layoutBuilder;
1686
1687         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1688         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1689         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1690
1691         vk::Unique<vk::VkDescriptorSetLayout>   descriptorSetLayout(layoutBuilder.build(vk, device));
1692
1693         vk::Unique<vk::VkDescriptorPool>                descriptorPool(vk::DescriptorPoolBuilder()
1694                 .addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 3u)
1695                 .build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
1696         vk::Unique<vk::VkDescriptorSet>                 descriptorSet           (makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));
1697
1698         const deUint32 specData[1] =
1699         {
1700                 invocationStride,
1701         };
1702         const vk::VkSpecializationMapEntry entries[1] =
1703         {
1704                 {0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
1705         };
1706         const vk::VkSpecializationInfo specInfo =
1707         {
1708                 1,                                              // mapEntryCount
1709                 entries,                                // pMapEntries
1710                 sizeof(specData),               // dataSize
1711                 specData                                // pData
1712         };
1713
1714         const VkPushConstantRange                               pushConstantRange                               =
1715         {
1716                 allShaderStages,                                                                                        // VkShaderStageFlags                                   stageFlags;
1717                 0u,                                                                                                                     // deUint32                                                             offset;
1718                 sizeof(deInt32)                                                                                         // deUint32                                                             size;
1719         };
1720
1721         const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo =
1722         {
1723                 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,                          // sType
1724                 DE_NULL,                                                                                                        // pNext
1725                 (VkPipelineLayoutCreateFlags)0,
1726                 1,                                                                                                                      // setLayoutCount
1727                 &descriptorSetLayout.get(),                                                                     // pSetLayouts
1728                 1u,                                                                                                                     // pushConstantRangeCount
1729                 &pushConstantRange,                                                                                     // pPushConstantRanges
1730         };
1731
1732         Move<VkPipelineLayout> pipelineLayout = createPipelineLayout(vk, device, &pipelineLayoutCreateInfo, NULL);
1733
1734         VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
1735
1736         flushAlloc(vk, device, buffers[0]->getAllocation());
1737         flushAlloc(vk, device, buffers[1]->getAllocation());
1738         flushAlloc(vk, device, buffers[2]->getAllocation());
1739
1740         const VkBool32 computeFullSubgroups = subgroupProperties.subgroupSize <= 64 &&
1741                                                                                   m_context.getSubgroupSizeControlFeaturesEXT().computeFullSubgroups;
1742
1743         const VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT subgroupSizeCreateInfo =
1744         {
1745                 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT, // VkStructureType              sType;
1746                 DE_NULL,                                                                                                                                                // void*                                pNext;
1747                 subgroupProperties.subgroupSize                                                                                                 // uint32_t                             requiredSubgroupSize;
1748         };
1749
1750         const void *shaderPNext = computeFullSubgroups ? &subgroupSizeCreateInfo : DE_NULL;
1751         VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags =
1752                 (VkPipelineShaderStageCreateFlags)(computeFullSubgroups ? VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT : 0);
1753
1754         const Unique<VkShaderModule>                    shader                                          (createShaderModule(vk, device, m_context.getBinaryCollection().get("test"), 0));
1755         const VkPipelineShaderStageCreateInfo   shaderCreateInfo =
1756         {
1757                 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1758                 shaderPNext,
1759                 pipelineShaderStageCreateFlags,
1760                 VK_SHADER_STAGE_COMPUTE_BIT,                                                            // stage
1761                 *shader,                                                                                                        // shader
1762                 "main",
1763                 &specInfo,                                                                                                      // pSpecializationInfo
1764         };
1765
1766         const VkComputePipelineCreateInfo               pipelineCreateInfo =
1767         {
1768                 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1769                 DE_NULL,
1770                 0u,                                                                                                                     // flags
1771                 shaderCreateInfo,                                                                                       // cs
1772                 *pipelineLayout,                                                                                        // layout
1773                 (vk::VkPipeline)0,                                                                                      // basePipelineHandle
1774                 0u,                                                                                                                     // basePipelineIndex
1775         };
1776         Move<VkPipeline> pipeline = createComputePipeline(vk, device, DE_NULL, &pipelineCreateInfo, NULL);
1777
1778         const VkQueue                                   queue                                   = m_context.getUniversalQueue();
1779         Move<VkCommandPool>                             cmdPool                                 = createCommandPool(vk, device, vk::VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT, m_context.getUniversalQueueFamilyIndex());
1780         Move<VkCommandBuffer>                   cmdBuffer                               = allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1781
1782
1783         vk::DescriptorSetUpdateBuilder setUpdateBuilder;
1784         setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
1785                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
1786         setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1787                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1788         setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
1789                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
1790         setUpdateBuilder.update(vk, device);
1791
1792         // compute "maxLoc", the maximum number of locations written
1793         beginCommandBuffer(vk, *cmdBuffer, 0u);
1794
1795         vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
1796         vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);
1797
1798         deInt32 pcinvocationStride = 0;
1799         vk.cmdPushConstants(*cmdBuffer, *pipelineLayout, allShaderStages, 0, sizeof(pcinvocationStride), &pcinvocationStride);
1800
1801         vk.cmdDispatch(*cmdBuffer, 1, 1, 1);
1802
1803         endCommandBuffer(vk, *cmdBuffer);
1804
1805         submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1806
1807         invalidateAlloc(vk, device, buffers[1]->getAllocation());
1808         invalidateAlloc(vk, device, buffers[2]->getAllocation());
1809
1810         // Clear any writes to buffer[1] during the counting pass
1811         deMemset(ptrs[1], 0, invocationStride * sizeof(deUint64));
1812
1813         // Take the max over all invocations. Add one (to make sure no additional writes are done) and multiply by
1814         // the number of invocations
1815         deUint32 newMaxLoc = 0;
1816         for (deUint32 id = 0; id < invocationStride; ++id)
1817                 newMaxLoc = de::max(newMaxLoc, ptrs[2][id]);
1818         newMaxLoc++;
1819         newMaxLoc *= invocationStride;
1820
1821         // If we need more space, reallocate buffers[1]
1822         if (newMaxLoc > maxLoc)
1823         {
1824                 maxLoc = newMaxLoc;
1825                 sizes[1] = maxLoc * sizeof(deUint64);
1826
1827                 try
1828                 {
1829                         buffers[1] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1830                                 vk, device, allocator, makeBufferCreateInfo(sizes[1], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT),
1831                                 MemoryRequirement::HostVisible | MemoryRequirement::Cached));
1832                 }
1833                 catch(tcu::ResourceError&)
1834                 {
1835                         // Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
1836                         return tcu::TestStatus(QP_TEST_RESULT_QUALITY_WARNING, "Failed device memory allocation " + de::toString(sizes[1]) + " bytes");
1837                 }
1838                 bufferDescriptors[1] = makeDescriptorBufferInfo(**buffers[1], 0, sizes[1]);
1839                 ptrs[1] = (deUint32 *)buffers[1]->getAllocation().getHostPtr();
1840                 deMemset(ptrs[1], 0, (size_t)sizes[1]);
1841
1842                 vk::DescriptorSetUpdateBuilder setUpdateBuilder2;
1843                 setUpdateBuilder2.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1844                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1845                 setUpdateBuilder2.update(vk, device);
1846         }
1847
1848         flushAlloc(vk, device, buffers[1]->getAllocation());
1849
1850         // run the actual shader
1851         beginCommandBuffer(vk, *cmdBuffer, 0u);
1852
1853         vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
1854         vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);
1855
1856         pcinvocationStride = invocationStride;
1857         vk.cmdPushConstants(*cmdBuffer, *pipelineLayout, allShaderStages, 0, sizeof(pcinvocationStride), &pcinvocationStride);
1858
1859         vk.cmdDispatch(*cmdBuffer, 1, 1, 1);
1860
1861         endCommandBuffer(vk, *cmdBuffer);
1862
1863         submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1864
1865         invalidateAlloc(vk, device, buffers[1]->getAllocation());
1866
1867         qpTestResult res = QP_TEST_RESULT_PASS;
1868
1869         // Simulate execution on the CPU, and compare against the GPU result
1870         deUint64 *ref = new deUint64 [maxLoc];
1871         deMemset(ref, 0, maxLoc*sizeof(deUint64));
1872         program.simulate(false, subgroupSize, invocationStride, ref);
1873
1874         const deUint64 *result = (const deUint64 *)ptrs[1];
1875
1876         if (m_data.testType == TT_MAXIMAL)
1877         {
1878                 // With maximal reconvergence, we should expect the output to exactly match
1879                 // the reference.
1880                 for (deUint32 i = 0; i < maxLoc; ++i)
1881                 {
1882                         if (result[i] != ref[i])
1883                         {
1884                                 log << tcu::TestLog::Message << "first mismatch at " << i << tcu::TestLog::EndMessage;
1885                                 res = QP_TEST_RESULT_FAIL;
1886                                 break;
1887                         }
1888                 }
1889
1890                 if (res != QP_TEST_RESULT_PASS)
1891                 {
1892                         for (deUint32 i = 0; i < maxLoc; ++i)
1893                         {
1894                                 // This log can be large and slow, ifdef it out by default
1895 #if 0
1896                                 log << tcu::TestLog::Message << "result " << i << "(" << (i/invocationStride) << ", " << (i%invocationStride) << "): " << tcu::toHex(result[i]) << " ref " << tcu::toHex(ref[i]) << (result[i] != ref[i] ? " different" : "") << tcu::TestLog::EndMessage;
1897 #endif
1898                         }
1899                 }
1900         }
1901         else
1902         {
1903                 deUint64 fullMask = subgroupSizeToMask(subgroupSize);
1904                 // For subgroup_uniform_control_flow, we expect any fully converged outputs in the reference
1905                 // to have a corresponding fully converged output in the result. So walk through each lane's
1906                 // results, and for each reference value of fullMask, find a corresponding result value of
1907                 // fullMask where the previous value (OP_STORE) matches. That means these came from the same
1908                 // source location.
1909                 vector<deUint32> firstFail(invocationStride, 0);
1910                 for (deUint32 lane = 0; lane < invocationStride; ++lane)
1911                 {
1912                         deUint32 resLoc = lane + invocationStride, refLoc = lane + invocationStride;
1913                         while (refLoc < maxLoc)
1914                         {
1915                                 while (refLoc < maxLoc && ref[refLoc] != fullMask)
1916                                         refLoc += invocationStride;
1917                                 if (refLoc >= maxLoc)
1918                                         break;
1919
1920                                 // For TT_SUCF_ELECT, when the reference result has a full mask, we expect lane 0 to be elected
1921                                 // (a value of 2) and all other lanes to be not elected (a value of 1). For TT_SUCF_BALLOT, we
1922                                 // expect a full mask. Search until we find the expected result with a matching store value in
1923                                 // the previous result.
1924                                 deUint64 expectedResult = m_data.isElect() ? ((lane % subgroupSize) == 0 ? 2 : 1)
1925                                                                                                                          : fullMask;
1926
1927                                 while (resLoc < maxLoc && !(result[resLoc] == expectedResult && result[resLoc-invocationStride] == ref[refLoc-invocationStride]))
1928                                         resLoc += invocationStride;
1929
1930                                 // If we didn't find this output in the result, flag it as an error.
1931                                 if (resLoc >= maxLoc)
1932                                 {
1933                                         firstFail[lane] = refLoc;
1934                                         log << tcu::TestLog::Message << "lane " << lane << " first mismatch at " << firstFail[lane] << tcu::TestLog::EndMessage;
1935                                         res = QP_TEST_RESULT_FAIL;
1936                                         break;
1937                                 }
1938                                 refLoc += invocationStride;
1939                                 resLoc += invocationStride;
1940                         }
1941                 }
1942
1943                 if (res != QP_TEST_RESULT_PASS)
1944                 {
1945                         for (deUint32 i = 0; i < maxLoc; ++i)
1946                         {
1947                                 // This log can be large and slow, ifdef it out by default
1948 #if 0
1949                                 log << tcu::TestLog::Message << "result " << i << "(" << (i/invocationStride) << ", " << (i%invocationStride) << "): " << tcu::toHex(result[i]) << " ref " << tcu::toHex(ref[i]) << (i == firstFail[i%invocationStride] ? " first fail" : "") << tcu::TestLog::EndMessage;
1950 #endif
1951                         }
1952                 }
1953         }
1954
1955         delete []ref;
1956
1957         return tcu::TestStatus(res, qpGetTestResultName(res));
1958 }
1959
1960 }       // anonymous
1961
1962 tcu::TestCaseGroup*     createTests (tcu::TestContext& testCtx, bool createExperimental)
1963 {
1964         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
1965                         testCtx, "reconvergence", "reconvergence tests"));
1966
1967         typedef struct
1968         {
1969                 deUint32                                value;
1970                 const char*                             name;
1971                 const char*                             description;
1972         } TestGroupCase;
1973
1974         TestGroupCase ttCases[] =
1975         {
1976                 { TT_SUCF_ELECT,                                "subgroup_uniform_control_flow_elect",  "subgroup_uniform_control_flow_elect"           },
1977                 { TT_SUCF_BALLOT,                               "subgroup_uniform_control_flow_ballot", "subgroup_uniform_control_flow_ballot"          },
1978                 { TT_WUCF_ELECT,                                "workgroup_uniform_control_flow_elect", "workgroup_uniform_control_flow_elect"          },
1979                 { TT_WUCF_BALLOT,                               "workgroup_uniform_control_flow_ballot","workgroup_uniform_control_flow_ballot"         },
1980                 { TT_MAXIMAL,                                   "maximal",                                                              "maximal"                                                                       },
1981         };
1982
1983         for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
1984         {
1985                 de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, ttCases[ttNdx].name, ttCases[ttNdx].description));
1986                 de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(testCtx, "compute", ""));
1987
1988                 for (deUint32 nNdx = 2; nNdx <= 6; nNdx++)
1989                 {
1990                         de::MovePtr<tcu::TestCaseGroup> nestGroup(new tcu::TestCaseGroup(testCtx, ("nesting" + de::toString(nNdx)).c_str(), ""));
1991
1992                         deUint32 seed = 0;
1993
1994                         for (int sNdx = 0; sNdx < 8; sNdx++)
1995                         {
1996                                 de::MovePtr<tcu::TestCaseGroup> seedGroup(new tcu::TestCaseGroup(testCtx, de::toString(sNdx).c_str(), ""));
1997
1998                                 deUint32 numTests = 0;
1999                                 switch (nNdx)
2000                                 {
2001                                 default:
2002                                         DE_ASSERT(0);
2003                                         // fallthrough
2004                                 case 2:
2005                                 case 3:
2006                                 case 4:
2007                                         numTests = 250;
2008                                         break;
2009                                 case 5:
2010                                         numTests = 100;
2011                                         break;
2012                                 case 6:
2013                                         numTests = 50;
2014                                         break;
2015                                 }
2016
2017                                 if (ttCases[ttNdx].value != TT_MAXIMAL)
2018                                 {
2019                                         if (nNdx >= 5)
2020                                                 continue;
2021                                 }
2022
2023                                 for (deUint32 ndx = 0; ndx < numTests; ndx++)
2024                                 {
2025                                         CaseDef c =
2026                                         {
2027                                                 (TestType)ttCases[ttNdx].value,         // TestType testType;
2028                                                 nNdx,                                                           // deUint32 maxNesting;
2029                                                 seed,                                                           // deUint32 seed;
2030                                         };
2031                                         seed++;
2032
2033                                         bool isExperimentalTest = !c.isUCF() || (ndx >= numTests / 5);
2034
2035                                         if (createExperimental == isExperimentalTest)
2036                                                 seedGroup->addChild(new ReconvergenceTestCase(testCtx, de::toString(ndx).c_str(), "", c));
2037                                 }
2038                                 if (!seedGroup->empty())
2039                                         nestGroup->addChild(seedGroup.release());
2040                         }
2041                         if (!nestGroup->empty())
2042                                 computeGroup->addChild(nestGroup.release());
2043                 }
2044                 if (!computeGroup->empty())
2045                 {
2046                         ttGroup->addChild(computeGroup.release());
2047                         group->addChild(ttGroup.release());
2048                 }
2049         }
2050         return group.release();
2051 }
2052
2053 }       // Reconvergence
2054 }       // vkt