Test behaviour of color write enable with colorWriteMask
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / reconvergence / vktReconvergenceTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2018-2020 NVIDIA Corporation
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *        http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Vulkan Reconvergence tests
23  *//*--------------------------------------------------------------------*/
24
25 #include "vktReconvergenceTests.hpp"
26
27 #include "vkBufferWithMemory.hpp"
28 #include "vkImageWithMemory.hpp"
29 #include "vkQueryUtil.hpp"
30 #include "vkBuilderUtil.hpp"
31 #include "vkCmdUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkObjUtil.hpp"
34
35 #include "vktTestGroupUtil.hpp"
36 #include "vktTestCase.hpp"
37
38 #include "deDefs.h"
39 #include "deFloat16.h"
40 #include "deMath.h"
41 #include "deRandom.h"
42 #include "deSharedPtr.hpp"
43 #include "deString.h"
44
45 #include "tcuTestCase.hpp"
46 #include "tcuTestLog.hpp"
47
48 #include <bitset>
49 #include <string>
50 #include <sstream>
51 #include <set>
52 #include <vector>
53
54 namespace vkt
55 {
56 namespace Reconvergence
57 {
58 namespace
59 {
60 using namespace vk;
61 using namespace std;
62
// Element count of a statically sized array (not valid for pointers).
// The argument is parenthesized before it is subscripted, so that expression
// arguments such as casts are indexed as a whole.
#define ARRAYSIZE(x) (sizeof(x) / sizeof((x)[0]))
64
65 const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;
66
// Flavours of reconvergence behaviour a test case can target.
enum TestType
{
        // subgroup_uniform_control_flow, probed with elect (subgroup_basic)
        TT_SUCF_ELECT,
        // subgroup_uniform_control_flow, probed with ballot (subgroup_ballot)
        TT_SUCF_BALLOT,
        // workgroup uniform control flow, probed with elect (subgroup_basic)
        TT_WUCF_ELECT,
        // workgroup uniform control flow, probed with ballot (subgroup_ballot)
        TT_WUCF_BALLOT,
        // maximal reconvergence
        TT_MAXIMAL,
};
74
75 struct CaseDef
76 {
77         TestType testType;
78         deUint32 maxNesting;
79         deUint32 seed;
80
81         bool isWUCF() const { return testType == TT_WUCF_ELECT || testType == TT_WUCF_BALLOT; }
82         bool isSUCF() const { return testType == TT_SUCF_ELECT || testType == TT_SUCF_BALLOT; }
83         bool isUCF() const { return isWUCF() || isSUCF(); }
84         bool isElect() const { return testType == TT_WUCF_ELECT || testType == TT_SUCF_ELECT; }
85 };
86
87 deUint64 subgroupSizeToMask(deUint32 subgroupSize)
88 {
89         if (subgroupSize == 64)
90                 return ~0ULL;
91         else
92                 return (1ULL << subgroupSize) - 1;
93 }
94
// 128-bit mask type used by the bitset helpers in this file.
using bitset128 = std::bitset<128>;
96
97 // Take a 64-bit integer, mask it to the subgroup size, and then
98 // replicate it for each subgroup
99 bitset128 bitsetFromU64(deUint64 mask, deUint32 subgroupSize)
100 {
101         mask &= subgroupSizeToMask(subgroupSize);
102         bitset128 result(mask);
103         for (deUint32 i = 0; i < 128 / subgroupSize - 1; ++i)
104         {
105                 result = (result << subgroupSize) | bitset128(mask);
106         }
107         return result;
108 }
109
110 // Pick out the mask for the subgroup that invocationID is a member of
111 deUint64 bitsetToU64(const bitset128 &bitset, deUint32 subgroupSize, deUint32 invocationID)
112 {
113         bitset128 copy(bitset);
114         copy >>= (invocationID / subgroupSize) * subgroupSize;
115         copy &= bitset128(subgroupSizeToMask(subgroupSize));
116         deUint64 mask = copy.to_ullong();
117         mask &= subgroupSizeToMask(subgroupSize);
118         return mask;
119 }
120
121 class ReconvergenceTestInstance : public TestInstance
122 {
123 public:
124                                                 ReconvergenceTestInstance       (Context& context, const CaseDef& data);
125                                                 ~ReconvergenceTestInstance      (void);
126         tcu::TestStatus         iterate                         (void);
127 private:
128         CaseDef                 m_data;
129 };
130
131 ReconvergenceTestInstance::ReconvergenceTestInstance (Context& context, const CaseDef& data)
132         : vkt::TestInstance             (context)
133         , m_data                                (data)
134 {
135 }
136
137 ReconvergenceTestInstance::~ReconvergenceTestInstance (void)
138 {
139 }
140
141 class ReconvergenceTestCase : public TestCase
142 {
143         public:
144                                                                 ReconvergenceTestCase           (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
145                                                                 ~ReconvergenceTestCase  (void);
146         virtual void                            initPrograms            (SourceCollections& programCollection) const;
147         virtual TestInstance*           createInstance          (Context& context) const;
148         virtual void                            checkSupport            (Context& context) const;
149
150 private:
151         CaseDef                                 m_data;
152 };
153
154 ReconvergenceTestCase::ReconvergenceTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
155         : vkt::TestCase (context, name, desc)
156         , m_data                (data)
157 {
158 }
159
160 ReconvergenceTestCase::~ReconvergenceTestCase   (void)
161 {
162 }
163
164 void ReconvergenceTestCase::checkSupport(Context& context) const
165 {
166         if (!context.contextSupports(vk::ApiVersion(1, 1, 0)))
167                 TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
168
169         vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
170         deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
171         subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
172
173         vk::VkPhysicalDeviceProperties2 properties2;
174         deMemset(&properties2, 0, sizeof(properties2));
175         properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
176         properties2.pNext = &subgroupProperties;
177
178         context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties2);
179
180         if (m_data.isElect() && !(subgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BASIC_BIT))
181                 TCU_THROW(NotSupportedError, "VK_SUBGROUP_FEATURE_BASIC_BIT not supported");
182
183         if (!m_data.isElect() && !(subgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BALLOT_BIT))
184                 TCU_THROW(NotSupportedError, "VK_SUBGROUP_FEATURE_BALLOT_BIT not supported");
185
186         if (!(context.getSubgroupProperties().supportedStages & VK_SHADER_STAGE_COMPUTE_BIT))
187                 TCU_THROW(NotSupportedError, "compute stage does not support subgroup operations");
188
189         // Both subgroup- AND workgroup-uniform tests are enabled by shaderSubgroupUniformControlFlow.
190         if (m_data.isUCF() && !context.getShaderSubgroupUniformControlFlowFeatures().shaderSubgroupUniformControlFlow)
191                 TCU_THROW(NotSupportedError, "shaderSubgroupUniformControlFlow not supported");
192
193         // XXX TODO: Check for maximal reconvergence support
194         // if (m_data.testType == TT_MAXIMAL ...)
195 }
196
// Instruction kinds of the generated random program.
enum OPType
{
        // Store subgroupBallot().
        // OP::caseValue starts at zero for this op and is set to 1 by simulate
        // when the ballot is not workgroup- (or subgroup-) uniform. In WUCF
        // modes only workgroup-uniform ballots are validated for correctness.
        OP_BALLOT,

        // Store a literal constant.
        OP_STORE,

        // if ((1ULL << gl_SubgroupInvocationID) & mask).
        // A mask of ~0ULL is special-cased into "if (inputA.a[idx] == idx)".
        OP_IF_MASK,
        OP_ELSE_MASK,
        OP_ENDIF,

        // if (gl_SubgroupInvocationID == loopIdxN), N being the most nested loop counter.
        OP_IF_LOOPCOUNT,
        OP_ELSE_LOOPCOUNT,

        // if (gl_LocalInvocationIndex >= inputA.a[N]), N being the most nested loop counter.
        OP_IF_LOCAL_INVOCATION_INDEX,
        OP_ELSE_LOCAL_INVOCATION_INDEX,

        // break / continue
        OP_BREAK,
        OP_CONTINUE,

        // if (subgroupElect())
        OP_ELECT,

        // Loop whose uniform iteration count is read from a buffer.
        OP_BEGIN_FOR_UNIF,
        OP_END_FOR_UNIF,

        // for (int loopIdxN = 0; loopIdxN < gl_SubgroupInvocationID + 1; ++loopIdxN)
        OP_BEGIN_FOR_VAR,
        OP_END_FOR_VAR,

        // for (int loopIdxN = 0;; ++loopIdxN, OP_BALLOT)
        // Always contains an "if (subgroupElect()) break;" and performs the
        // equivalent of OP_BALLOT in its continue construct.
        OP_BEGIN_FOR_INF,
        OP_END_FOR_INF,

        // do { loopIdxN++; ... } while (loopIdxN < uniformValue);
        OP_BEGIN_DO_WHILE_UNIF,
        OP_END_DO_WHILE_UNIF,

        // do { ... } while (true);
        // Always contains an "if (subgroupElect()) break;".
        OP_BEGIN_DO_WHILE_INF,
        OP_END_DO_WHILE_INF,

        // return;
        OP_RETURN,

        // Function call: the code bracketed by these two ops is extracted into
        // a separate function.
        OP_CALL_BEGIN,
        OP_CALL_END,

        // switch on a uniform value
        OP_SWITCH_UNIF_BEGIN,
        // switch on the value of gl_SubgroupInvocationID & 3
        OP_SWITCH_VAR_BEGIN,
        // switch on a loopIdx value
        OP_SWITCH_LOOP_COUNT_BEGIN,

        // case label described by an (invocation mask, case mask) pair
        OP_CASE_MASK_BEGIN,
        // case label for loop counter switches: a value plus a mask of loop iterations
        OP_CASE_LOOP_COUNT_BEGIN,

        // end of a switch / end of a case
        OP_SWITCH_END,
        OP_CASE_END,

        // Extra code with no functional effect. Currently includes:
        // - value 0: while (!subgroupElect()) {}
        // - value 1: if (condition_that_is_false) { infinite loop }
        OP_NOISE,
};
281
// The different conditions an emitted if statement can test.
enum IFType
{
        IF_MASK,
        IF_UNIFORM,
        IF_LOOPCOUNT,
        IF_LOCAL_INVOCATION_INDEX,
};
290
291 class OP
292 {
293 public:
294         OP(OPType _type, deUint64 _value, deUint32 _caseValue = 0)
295                 : type(_type), value(_value), caseValue(_caseValue)
296         {}
297
298         // The type of operation and an optional value.
299         // The value could be a mask for an if test, the index of the loop
300         // header for an end of loop, or the constant value for a store instruction
301         OPType type;
302         deUint64 value;
303         deUint32 caseValue;
304 };
305
306 static int findLSB (deUint64 value)
307 {
308         for (int i = 0; i < 64; i++)
309         {
310                 if (value & (1ULL<<i))
311                         return i;
312         }
313         return -1;
314 }
315
316 // For each subgroup, pick out the elected invocationID, and accumulate
317 // a bitset of all of them
318 static bitset128 bitsetElect (const bitset128& value, deInt32 subgroupSize)
319 {
320         bitset128 ret; // zero initialized
321
322         for (deInt32 i = 0; i < 128; i += subgroupSize)
323         {
324                 deUint64 mask = bitsetToU64(value, subgroupSize, i);
325                 int lsb = findLSB(mask);
326                 ret |= bitset128(lsb == -1 ? 0 : (1ULL << lsb)) << i;
327         }
328         return ret;
329 }
330
331 class RandomProgram
332 {
333 public:
334         RandomProgram(const CaseDef &c)
335                 : caseDef(c), numMasks(5), nesting(0), maxNesting(c.maxNesting), loopNesting(0), loopNestingThisFunction(0), callNesting(0), minCount(30), indent(0), isLoopInf(100, false), doneInfLoopBreak(100, false), storeBase(0x10000)
336         {
337                 deRandom_init(&rnd, caseDef.seed);
338                 for (int i = 0; i < numMasks; ++i)
339                         masks.push_back(deRandom_getUint64(&rnd));
340         }
341
342         const CaseDef caseDef;
343         deRandom rnd;
344         vector<OP> ops;
345         vector<deUint64> masks;
346         deInt32 numMasks;
347         deInt32 nesting;
348         deInt32 maxNesting;
349         deInt32 loopNesting;
350         deInt32 loopNestingThisFunction;
351         deInt32 callNesting;
352         deInt32 minCount;
353         deInt32 indent;
354         vector<bool> isLoopInf;
355         vector<bool> doneInfLoopBreak;
356         // Offset the value we use for OP_STORE, to avoid colliding with fully converged
357         // active masks with small subgroup sizes (e.g. with subgroupSize == 4, the SUCF
358         // tests need to know that 0xF is really an active mask).
359         deInt32 storeBase;
360
361         void genIf(IFType ifType)
362         {
363                 deUint32 maskIdx = deRandom_getUint32(&rnd) % numMasks;
364                 deUint64 mask = masks[maskIdx];
365                 if (ifType == IF_UNIFORM)
366                         mask = ~0ULL;
367
368                 deUint32 localIndexCmp = deRandom_getUint32(&rnd) % 128;
369                 if (ifType == IF_LOCAL_INVOCATION_INDEX)
370                         ops.push_back({OP_IF_LOCAL_INVOCATION_INDEX, localIndexCmp});
371                 else if (ifType == IF_LOOPCOUNT)
372                         ops.push_back({OP_IF_LOOPCOUNT, 0});
373                 else
374                         ops.push_back({OP_IF_MASK, mask});
375
376                 nesting++;
377
378                 size_t thenBegin = ops.size();
379                 pickOP(2);
380                 size_t thenEnd = ops.size();
381
382                 deUint32 randElse = (deRandom_getUint32(&rnd) % 100);
383                 if (randElse < 50)
384                 {
385                         if (ifType == IF_LOCAL_INVOCATION_INDEX)
386                                 ops.push_back({OP_ELSE_LOCAL_INVOCATION_INDEX, localIndexCmp});
387                         else if (ifType == IF_LOOPCOUNT)
388                                 ops.push_back({OP_ELSE_LOOPCOUNT, 0});
389                         else
390                                 ops.push_back({OP_ELSE_MASK, 0});
391
392                         if (randElse < 10)
393                         {
394                                 // Sometimes make the else block identical to the then block
395                                 for (size_t i = thenBegin; i < thenEnd; ++i)
396                                         ops.push_back(ops[i]);
397                         }
398                         else
399                                 pickOP(2);
400                 }
401                 ops.push_back({OP_ENDIF, 0});
402                 nesting--;
403         }
404
405         void genForUnif()
406         {
407                 deUint32 iterCount = (deRandom_getUint32(&rnd) % 5) + 1;
408                 ops.push_back({OP_BEGIN_FOR_UNIF, iterCount});
409                 deUint32 loopheader = (deUint32)ops.size()-1;
410                 nesting++;
411                 loopNesting++;
412                 loopNestingThisFunction++;
413                 pickOP(2);
414                 ops.push_back({OP_END_FOR_UNIF, loopheader});
415                 loopNestingThisFunction--;
416                 loopNesting--;
417                 nesting--;
418         }
419
420         void genDoWhileUnif()
421         {
422                 deUint32 iterCount = (deRandom_getUint32(&rnd) % 5) + 1;
423                 ops.push_back({OP_BEGIN_DO_WHILE_UNIF, iterCount});
424                 deUint32 loopheader = (deUint32)ops.size()-1;
425                 nesting++;
426                 loopNesting++;
427                 loopNestingThisFunction++;
428                 pickOP(2);
429                 ops.push_back({OP_END_DO_WHILE_UNIF, loopheader});
430                 loopNestingThisFunction--;
431                 loopNesting--;
432                 nesting--;
433         }
434
435         void genForVar()
436         {
437                 ops.push_back({OP_BEGIN_FOR_VAR, 0});
438                 deUint32 loopheader = (deUint32)ops.size()-1;
439                 nesting++;
440                 loopNesting++;
441                 loopNestingThisFunction++;
442                 pickOP(2);
443                 ops.push_back({OP_END_FOR_VAR, loopheader});
444                 loopNestingThisFunction--;
445                 loopNesting--;
446                 nesting--;
447         }
448
449         void genForInf()
450         {
451                 ops.push_back({OP_BEGIN_FOR_INF, 0});
452                 deUint32 loopheader = (deUint32)ops.size()-1;
453
454                 nesting++;
455                 loopNesting++;
456                 loopNestingThisFunction++;
457                 isLoopInf[loopNesting] = true;
458                 doneInfLoopBreak[loopNesting] = false;
459
460                 pickOP(2);
461
462                 genElect(true);
463                 doneInfLoopBreak[loopNesting] = true;
464
465                 pickOP(2);
466
467                 ops.push_back({OP_END_FOR_INF, loopheader});
468
469                 isLoopInf[loopNesting] = false;
470                 doneInfLoopBreak[loopNesting] = false;
471                 loopNestingThisFunction--;
472                 loopNesting--;
473                 nesting--;
474         }
475
476         void genDoWhileInf()
477         {
478                 ops.push_back({OP_BEGIN_DO_WHILE_INF, 0});
479                 deUint32 loopheader = (deUint32)ops.size()-1;
480
481                 nesting++;
482                 loopNesting++;
483                 loopNestingThisFunction++;
484                 isLoopInf[loopNesting] = true;
485                 doneInfLoopBreak[loopNesting] = false;
486
487                 pickOP(2);
488
489                 genElect(true);
490                 doneInfLoopBreak[loopNesting] = true;
491
492                 pickOP(2);
493
494                 ops.push_back({OP_END_DO_WHILE_INF, loopheader});
495
496                 isLoopInf[loopNesting] = false;
497                 doneInfLoopBreak[loopNesting] = false;
498                 loopNestingThisFunction--;
499                 loopNesting--;
500                 nesting--;
501         }
502
503         void genBreak()
504         {
505                 if (loopNestingThisFunction > 0)
506                 {
507                         // Sometimes put the break in a divergent if
508                         if ((deRandom_getUint32(&rnd) % 100) < 10)
509                         {
510                                 ops.push_back({OP_IF_MASK, masks[0]});
511                                 ops.push_back({OP_BREAK, 0});
512                                 ops.push_back({OP_ELSE_MASK, 0});
513                                 ops.push_back({OP_BREAK, 0});
514                                 ops.push_back({OP_ENDIF, 0});
515                         }
516                         else
517                                 ops.push_back({OP_BREAK, 0});
518                 }
519         }
520
521         void genContinue()
522         {
523                 // continues are allowed if we're in a loop and the loop is not infinite,
524                 // or if it is infinite and we've already done a subgroupElect+break.
525                 // However, adding more continues seems to reduce the failure rate, so
526                 // disable it for now
527                 if (loopNestingThisFunction > 0 && !(isLoopInf[loopNesting] /*&& !doneInfLoopBreak[loopNesting]*/))
528                 {
529                         // Sometimes put the continue in a divergent if
530                         if ((deRandom_getUint32(&rnd) % 100) < 10)
531                         {
532                                 ops.push_back({OP_IF_MASK, masks[0]});
533                                 ops.push_back({OP_CONTINUE, 0});
534                                 ops.push_back({OP_ELSE_MASK, 0});
535                                 ops.push_back({OP_CONTINUE, 0});
536                                 ops.push_back({OP_ENDIF, 0});
537                         }
538                         else
539                                 ops.push_back({OP_CONTINUE, 0});
540                 }
541         }
542
543         // doBreak is used to generate "if (subgroupElect()) { ... break; }" inside infinite loops
544         void genElect(bool doBreak)
545         {
546                 ops.push_back({OP_ELECT, 0});
547                 nesting++;
548                 if (doBreak)
549                 {
550                         // Put something interestign before the break
551                         optBallot();
552                         optBallot();
553                         if ((deRandom_getUint32(&rnd) % 100) < 10)
554                                 pickOP(1);
555
556                         // if we're in a function, sometimes  use return instead
557                         if (callNesting > 0 && (deRandom_getUint32(&rnd) % 100) < 30)
558                                 ops.push_back({OP_RETURN, 0});
559                         else
560                                 genBreak();
561
562                 }
563                 else
564                         pickOP(2);
565
566                 ops.push_back({OP_ENDIF, 0});
567                 nesting--;
568         }
569
570         void genReturn()
571         {
572                 deUint32 r = deRandom_getUint32(&rnd) % 100;
573                 if (nesting > 0 &&
574                         // Use return rarely in main, 20% of the time in a singly nested loop in a function
575                         // and 50% of the time in a multiply nested loop in a function
576                         (r < 5 ||
577                          (callNesting > 0 && loopNestingThisFunction > 0 && r < 20) ||
578                          (callNesting > 0 && loopNestingThisFunction > 1 && r < 50)))
579                 {
580                         optBallot();
581                         if ((deRandom_getUint32(&rnd) % 100) < 10)
582                         {
583                                 ops.push_back({OP_IF_MASK, masks[0]});
584                                 ops.push_back({OP_RETURN, 0});
585                                 ops.push_back({OP_ELSE_MASK, 0});
586                                 ops.push_back({OP_RETURN, 0});
587                                 ops.push_back({OP_ENDIF, 0});
588                         }
589                         else
590                                 ops.push_back({OP_RETURN, 0});
591                 }
592         }
593
594         // Generate a function call. Save and restore some loop information, which is used to
595         // determine when it's safe to use break/continue
596         void genCall()
597         {
598                 ops.push_back({OP_CALL_BEGIN, 0});
599                 callNesting++;
600                 nesting++;
601                 deInt32 saveLoopNestingThisFunction = loopNestingThisFunction;
602                 loopNestingThisFunction = 0;
603
604                 pickOP(2);
605
606                 loopNestingThisFunction = saveLoopNestingThisFunction;
607                 nesting--;
608                 callNesting--;
609                 ops.push_back({OP_CALL_END, 0});
610         }
611
612         // Generate switch on a uniform value:
613         // switch (inputA.a[r]) {
614         // case r+1: ... break; // should not execute
615         // case r:   ... break; // should branch uniformly
616         // case r+2: ... break; // should not execute
617         // }
618         void genSwitchUnif()
619         {
620                 deUint32 r = deRandom_getUint32(&rnd) % 5;
621                 ops.push_back({OP_SWITCH_UNIF_BEGIN, r});
622                 nesting++;
623
624                 ops.push_back({OP_CASE_MASK_BEGIN, 0, 1u<<(r+1)});
625                 pickOP(1);
626                 ops.push_back({OP_CASE_END, 0});
627
628                 ops.push_back({OP_CASE_MASK_BEGIN, ~0ULL, 1u<<r});
629                 pickOP(2);
630                 ops.push_back({OP_CASE_END, 0});
631
632                 ops.push_back({OP_CASE_MASK_BEGIN, 0, 1u<<(r+2)});
633                 pickOP(1);
634                 ops.push_back({OP_CASE_END, 0});
635
636                 ops.push_back({OP_SWITCH_END, 0});
637                 nesting--;
638         }
639
640         // switch (gl_SubgroupInvocationID & 3) with four unique targets
641         void genSwitchVar()
642         {
643                 ops.push_back({OP_SWITCH_VAR_BEGIN, 0});
644                 nesting++;
645
646                 ops.push_back({OP_CASE_MASK_BEGIN, 0x1111111111111111ULL, 1<<0});
647                 pickOP(1);
648                 ops.push_back({OP_CASE_END, 0});
649
650                 ops.push_back({OP_CASE_MASK_BEGIN, 0x2222222222222222ULL, 1<<1});
651                 pickOP(1);
652                 ops.push_back({OP_CASE_END, 0});
653
654                 ops.push_back({OP_CASE_MASK_BEGIN, 0x4444444444444444ULL, 1<<2});
655                 pickOP(1);
656                 ops.push_back({OP_CASE_END, 0});
657
658                 ops.push_back({OP_CASE_MASK_BEGIN, 0x8888888888888888ULL, 1<<3});
659                 pickOP(1);
660                 ops.push_back({OP_CASE_END, 0});
661
662                 ops.push_back({OP_SWITCH_END, 0});
663                 nesting--;
664         }
665
666         // switch (gl_SubgroupInvocationID & 3) with two shared targets.
667         // XXX TODO: The test considers these two targets to remain converged,
668         // though we haven't agreed to that behavior yet.
669         void genSwitchMulticase()
670         {
671                 ops.push_back({OP_SWITCH_VAR_BEGIN, 0});
672                 nesting++;
673
674                 ops.push_back({OP_CASE_MASK_BEGIN, 0x3333333333333333ULL, (1<<0)|(1<<1)});
675                 pickOP(2);
676                 ops.push_back({OP_CASE_END, 0});
677
678                 ops.push_back({OP_CASE_MASK_BEGIN, 0xCCCCCCCCCCCCCCCCULL, (1<<2)|(1<<3)});
679                 pickOP(2);
680                 ops.push_back({OP_CASE_END, 0});
681
682                 ops.push_back({OP_SWITCH_END, 0});
683                 nesting--;
684         }
685
686         // switch (loopIdxN) {
687         // case 1:  ... break;
688         // case 2:  ... break;
689         // default: ... break;
690         // }
691         void genSwitchLoopCount()
692         {
693                 deUint32 r = deRandom_getUint32(&rnd) % loopNesting;
694                 ops.push_back({OP_SWITCH_LOOP_COUNT_BEGIN, r});
695                 nesting++;
696
697                 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, 1ULL<<1, 1});
698                 pickOP(1);
699                 ops.push_back({OP_CASE_END, 0});
700
701                 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, 1ULL<<2, 2});
702                 pickOP(1);
703                 ops.push_back({OP_CASE_END, 0});
704
705                 // default:
706                 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, ~6ULL, 0xFFFFFFFF});
707                 pickOP(1);
708                 ops.push_back({OP_CASE_END, 0});
709
710                 ops.push_back({OP_SWITCH_END, 0});
711                 nesting--;
712         }
713
714         void pickOP(deUint32 count)
715         {
716                 // Pick "count" instructions. These can recursively insert more instructions,
717                 // so "count" is just a seed
718                 for (deUint32 i = 0; i < count; ++i)
719                 {
720                         optBallot();
721                         if (nesting < maxNesting)
722                         {
723                                 deUint32 r = deRandom_getUint32(&rnd) % 11;
724                                 switch (r)
725                                 {
726                                 default:
727                                         DE_ASSERT(0);
728                                         // fallthrough
729                                 case 2:
730                                         if (loopNesting)
731                                         {
732                                                 genIf(IF_LOOPCOUNT);
733                                                 break;
734                                         }
735                                         // fallthrough
736                                 case 10:
737                                         genIf(IF_LOCAL_INVOCATION_INDEX);
738                                         break;
739                                 case 0:
740                                         genIf(IF_MASK);
741                                         break;
742                                 case 1:
743                                         genIf(IF_UNIFORM);
744                                         break;
745                                 case 3:
746                                         {
747                                                 // don't nest loops too deeply, to avoid extreme memory usage or timeouts
748                                                 if (loopNesting <= 3)
749                                                 {
750                                                         deUint32 r2 = deRandom_getUint32(&rnd) % 3;
751                                                         switch (r2)
752                                                         {
753                                                         default: DE_ASSERT(0); // fallthrough
754                                                         case 0: genForUnif(); break;
755                                                         case 1: genForInf(); break;
756                                                         case 2: genForVar(); break;
757                                                         }
758                                                 }
759                                         }
760                                         break;
761                                 case 4:
762                                         genBreak();
763                                         break;
764                                 case 5:
765                                         genContinue();
766                                         break;
767                                 case 6:
768                                         genElect(false);
769                                         break;
770                                 case 7:
771                                         {
772                                                 deUint32 r2 = deRandom_getUint32(&rnd) % 5;
773                                                 if (r2 == 0 && callNesting == 0 && nesting < maxNesting - 2)
774                                                         genCall();
775                                                 else
776                                                         genReturn();
777                                                 break;
778                                         }
779                                 case 8:
780                                         {
781                                                 // don't nest loops too deeply, to avoid extreme memory usage or timeouts
782                                                 if (loopNesting <= 3)
783                                                 {
784                                                         deUint32 r2 = deRandom_getUint32(&rnd) % 2;
785                                                         switch (r2)
786                                                         {
787                                                         default: DE_ASSERT(0); // fallthrough
788                                                         case 0: genDoWhileUnif(); break;
789                                                         case 1: genDoWhileInf(); break;
790                                                         }
791                                                 }
792                                         }
793                                         break;
794                                 case 9:
795                                         {
796                                                 deUint32 r2 = deRandom_getUint32(&rnd) % 4;
797                                                 switch (r2)
798                                                 {
799                                                 default:
800                                                         DE_ASSERT(0);
801                                                         // fallthrough
802                                                 case 0:
803                                                         genSwitchUnif();
804                                                         break;
805                                                 case 1:
806                                                         if (loopNesting > 0) {
807                                                                 genSwitchLoopCount();
808                                                                 break;
809                                                         }
810                                                         // fallthrough
811                                                 case 2:
812                                                         if (caseDef.testType != TT_MAXIMAL)
813                                                         {
814                                                                 // multicase doesn't have fully-defined behavior for MAXIMAL tests,
815                                                                 // but does for SUCF tests
816                                                                 genSwitchMulticase();
817                                                                 break;
818                                                         }
819                                                         // fallthrough
820                                                 case 3:
821                                                         genSwitchVar();
822                                                         break;
823                                                 }
824                                         }
825                                         break;
826                                 }
827                         }
828                         optBallot();
829                 }
830         }
831
832         void optBallot()
833         {
834                 // optionally insert ballots, stores, and noise. Ballots and stores are used to determine
835                 // correctness.
836                 if ((deRandom_getUint32(&rnd) % 100) < 20)
837                 {
838                         if (ops.size() < 2 ||
839                            !(ops[ops.size()-1].type == OP_BALLOT ||
840                                  (ops[ops.size()-1].type == OP_STORE && ops[ops.size()-2].type == OP_BALLOT)))
841                         {
842                                 // do a store along with each ballot, so we can correlate where
843                                 // the ballot came from
844                                 if (caseDef.testType != TT_MAXIMAL)
845                                         ops.push_back({OP_STORE, (deUint32)ops.size() + storeBase});
846                                 ops.push_back({OP_BALLOT, 0});
847                         }
848                 }
849
850                 if ((deRandom_getUint32(&rnd) % 100) < 10)
851                 {
852                         if (ops.size() < 2 ||
853                            !(ops[ops.size()-1].type == OP_STORE ||
854                                  (ops[ops.size()-1].type == OP_BALLOT && ops[ops.size()-2].type == OP_STORE)))
855                         {
856                                 // SUCF does a store with every ballot. Don't bloat the code by adding more.
857                                 if (caseDef.testType == TT_MAXIMAL)
858                                         ops.push_back({OP_STORE, (deUint32)ops.size() + storeBase});
859                         }
860                 }
861
862                 deUint32 r = deRandom_getUint32(&rnd) % 10000;
863                 if (r < 3)
864                         ops.push_back({OP_NOISE, 0});
865                 else if (r < 10)
866                         ops.push_back({OP_NOISE, 1});
867         }
868
869         void generateRandomProgram()
870         {
871                 do {
872                         ops.clear();
873                         while ((deInt32)ops.size() < minCount)
874                                 pickOP(1);
875
876                         // Retry until the program has some UCF results in it
877                         if (caseDef.isUCF())
878                         {
879                                 const deUint32 invocationStride = 128;
880                                 // Simulate for all subgroup sizes, to determine whether OP_BALLOTs are nonuniform
881                                 for (deInt32 subgroupSize = 4; subgroupSize <= 64; subgroupSize *= 2) {
882                                         simulate(true, subgroupSize, invocationStride, DE_NULL);
883                                 }
884                         }
885                 } while (caseDef.isUCF() && !hasUCF());
886         }
887
888         void printIndent(std::stringstream &css)
889         {
890                 for (deInt32 i = 0; i < indent; ++i)
891                         css << " ";
892         }
893
894         std::string genPartitionBallot()
895         {
896                 std::stringstream ss;
897                 ss << "subgroupBallot(true).xy";
898                 return ss.str();
899         }
900
901         void printBallot(std::stringstream *css)
902         {
903                 *css << "outputC.loc[gl_LocalInvocationIndex]++,";
904                 // When inside loop(s), use partitionBallot rather than subgroupBallot to compute
905                 // a ballot, to make sure the ballot is "diverged enough". Don't do this for
906                 // subgroup_uniform_control_flow, since we only validate results that must be fully
907                 // reconverged.
908                 if (loopNesting > 0 && caseDef.testType == TT_MAXIMAL)
909                 {
910                         *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex] = " << genPartitionBallot();
911                 }
912                 else if (caseDef.isElect())
913                 {
914                         *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex].x = elect()";
915                 }
916                 else
917                 {
918                         *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex] = subgroupBallot(true).xy";
919                 }
920         }
921
922         void genCode(std::stringstream &functions, std::stringstream &main)
923         {
924                 std::stringstream *css = &main;
925                 indent = 4;
926                 loopNesting = 0;
927                 int funcNum = 0;
928                 for (deInt32 i = 0; i < (deInt32)ops.size(); ++i)
929                 {
930                         switch (ops[i].type)
931                         {
932                         case OP_IF_MASK:
933                                 printIndent(*css);
934                                 if (ops[i].value == ~0ULL)
935                                 {
936                                         // This equality test will always succeed, since inputA.a[i] == i
937                                         int idx = deRandom_getUint32(&rnd) % 4;
938                                         *css << "if (inputA.a[" << idx << "] == " << idx << ") {\n";
939                                 }
940                                 else
941                                         *css << "if (testBit(uvec2(0x" << std::hex << (ops[i].value & 0xFFFFFFFF) << ", 0x" << (ops[i].value >> 32) << "), gl_SubgroupInvocationID)) {\n";
942
943                                 indent += 4;
944                                 break;
945                         case OP_IF_LOOPCOUNT:
946                                 printIndent(*css); *css << "if (gl_SubgroupInvocationID == loopIdx" << loopNesting - 1 << ") {\n";
947                                 indent += 4;
948                                 break;
949                         case OP_IF_LOCAL_INVOCATION_INDEX:
950                                 printIndent(*css); *css << "if (gl_LocalInvocationIndex >= inputA.a[0x" << std::hex << ops[i].value << "]) {\n";
951                                 indent += 4;
952                                 break;
953                         case OP_ELSE_MASK:
954                         case OP_ELSE_LOOPCOUNT:
955                         case OP_ELSE_LOCAL_INVOCATION_INDEX:
956                                 indent -= 4;
957                                 printIndent(*css); *css << "} else {\n";
958                                 indent += 4;
959                                 break;
960                         case OP_ENDIF:
961                                 indent -= 4;
962                                 printIndent(*css); *css << "}\n";
963                                 break;
964                         case OP_BALLOT:
965                                 printIndent(*css); printBallot(css); *css << ";\n";
966                                 break;
967                         case OP_STORE:
968                                 printIndent(*css); *css << "outputC.loc[gl_LocalInvocationIndex]++;\n";
969                                 printIndent(*css); *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex].x = 0x" << std::hex << ops[i].value << ";\n";
970                                 break;
971                         case OP_BEGIN_FOR_UNIF:
972                                 printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;\n";
973                                 printIndent(*css); *css << "         loopIdx" << loopNesting << " < inputA.a[" << ops[i].value << "];\n";
974                                 printIndent(*css); *css << "         loopIdx" << loopNesting << "++) {\n";
975                                 indent += 4;
976                                 loopNesting++;
977                                 break;
978                         case OP_END_FOR_UNIF:
979                                 loopNesting--;
980                                 indent -= 4;
981                                 printIndent(*css); *css << "}\n";
982                                 break;
983                         case OP_BEGIN_DO_WHILE_UNIF:
984                                 printIndent(*css); *css << "{\n";
985                                 indent += 4;
986                                 printIndent(*css); *css << "int loopIdx" << loopNesting << " = 0;\n";
987                                 printIndent(*css); *css << "do {\n";
988                                 indent += 4;
989                                 printIndent(*css); *css << "loopIdx" << loopNesting << "++;\n";
990                                 loopNesting++;
991                                 break;
992                         case OP_BEGIN_DO_WHILE_INF:
993                                 printIndent(*css); *css << "{\n";
994                                 indent += 4;
995                                 printIndent(*css); *css << "int loopIdx" << loopNesting << " = 0;\n";
996                                 printIndent(*css); *css << "do {\n";
997                                 indent += 4;
998                                 loopNesting++;
999                                 break;
1000                         case OP_END_DO_WHILE_UNIF:
1001                                 loopNesting--;
1002                                 indent -= 4;
1003                                 printIndent(*css); *css << "} while (loopIdx" << loopNesting << " < inputA.a[" << ops[(deUint32)ops[i].value].value << "]);\n";
1004                                 indent -= 4;
1005                                 printIndent(*css); *css << "}\n";
1006                                 break;
1007                         case OP_END_DO_WHILE_INF:
1008                                 loopNesting--;
1009                                 printIndent(*css); *css << "loopIdx" << loopNesting << "++;\n";
1010                                 indent -= 4;
1011                                 printIndent(*css); *css << "} while (true);\n";
1012                                 indent -= 4;
1013                                 printIndent(*css); *css << "}\n";
1014                                 break;
1015                         case OP_BEGIN_FOR_VAR:
1016                                 printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;\n";
1017                                 printIndent(*css); *css << "         loopIdx" << loopNesting << " < gl_SubgroupInvocationID + 1;\n";
1018                                 printIndent(*css); *css << "         loopIdx" << loopNesting << "++) {\n";
1019                                 indent += 4;
1020                                 loopNesting++;
1021                                 break;
1022                         case OP_END_FOR_VAR:
1023                                 loopNesting--;
1024                                 indent -= 4;
1025                                 printIndent(*css); *css << "}\n";
1026                                 break;
1027                         case OP_BEGIN_FOR_INF:
1028                                 printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;;loopIdx" << loopNesting << "++,";
1029                                 loopNesting++;
1030                                 printBallot(css);
1031                                 *css << ") {\n";
1032                                 indent += 4;
1033                                 break;
1034                         case OP_END_FOR_INF:
1035                                 loopNesting--;
1036                                 indent -= 4;
1037                                 printIndent(*css); *css << "}\n";
1038                                 break;
1039                         case OP_BREAK:
1040                                 printIndent(*css); *css << "break;\n";
1041                                 break;
1042                         case OP_CONTINUE:
1043                                 printIndent(*css); *css << "continue;\n";
1044                                 break;
1045                         case OP_ELECT:
1046                                 printIndent(*css); *css << "if (subgroupElect()) {\n";
1047                                 indent += 4;
1048                                 break;
1049                         case OP_RETURN:
1050                                 printIndent(*css); *css << "return;\n";
1051                                 break;
1052                         case OP_CALL_BEGIN:
1053                                 printIndent(*css); *css << "func" << funcNum << "(";
1054                                 for (deInt32 n = 0; n < loopNesting; ++n)
1055                                 {
1056                                         *css << "loopIdx" << n;
1057                                         if (n != loopNesting - 1)
1058                                                 *css << ", ";
1059                                 }
1060                                 *css << ");\n";
1061                                 css = &functions;
1062                                 printIndent(*css); *css << "void func" << funcNum << "(";
1063                                 for (deInt32 n = 0; n < loopNesting; ++n)
1064                                 {
1065                                         *css << "int loopIdx" << n;
1066                                         if (n != loopNesting - 1)
1067                                                 *css << ", ";
1068                                 }
1069                                 *css << ") {\n";
1070                                 indent += 4;
1071                                 funcNum++;
1072                                 break;
1073                         case OP_CALL_END:
1074                                 indent -= 4;
1075                                 printIndent(*css); *css << "}\n";
1076                                 css = &main;
1077                                 break;
1078                         case OP_NOISE:
1079                                 if (ops[i].value == 0)
1080                                 {
1081                                         printIndent(*css); *css << "while (!subgroupElect()) {}\n";
1082                                 }
1083                                 else
1084                                 {
1085                                         printIndent(*css); *css << "if (inputA.a[0] == 12345) {\n";
1086                                         indent += 4;
1087                                         printIndent(*css); *css << "while (true) {\n";
1088                                         indent += 4;
1089                                         printIndent(*css); printBallot(css); *css << ";\n";
1090                                         indent -= 4;
1091                                         printIndent(*css); *css << "}\n";
1092                                         indent -= 4;
1093                                         printIndent(*css); *css << "}\n";
1094                                 }
1095                                 break;
1096                         case OP_SWITCH_UNIF_BEGIN:
1097                                 printIndent(*css); *css << "switch (inputA.a[" << ops[i].value << "]) {\n";
1098                                 indent += 4;
1099                                 break;
1100                         case OP_SWITCH_VAR_BEGIN:
1101                                 printIndent(*css); *css << "switch (gl_SubgroupInvocationID & 3) {\n";
1102                                 indent += 4;
1103                                 break;
1104                         case OP_SWITCH_LOOP_COUNT_BEGIN:
1105                                 printIndent(*css); *css << "switch (loopIdx" << ops[i].value << ") {\n";
1106                                 indent += 4;
1107                                 break;
1108                         case OP_SWITCH_END:
1109                                 indent -= 4;
1110                                 printIndent(*css); *css << "}\n";
1111                                 break;
1112                         case OP_CASE_MASK_BEGIN:
1113                                 for (deInt32 b = 0; b < 32; ++b)
1114                                 {
1115                                         if ((1u<<b) & ops[i].caseValue)
1116                                         {
1117                                                 printIndent(*css); *css << "case " << b << ":\n";
1118                                         }
1119                                 }
1120                                 printIndent(*css); *css << "{\n";
1121                                 indent += 4;
1122                                 break;
1123                         case OP_CASE_LOOP_COUNT_BEGIN:
1124                                 if (ops[i].caseValue == 0xFFFFFFFF)
1125                                 {
1126                                         printIndent(*css); *css << "default: {\n";
1127                                 }
1128                                 else
1129                                 {
1130                                         printIndent(*css); *css << "case " << ops[i].caseValue << ": {\n";
1131                                 }
1132                                 indent += 4;
1133                                 break;
1134                         case OP_CASE_END:
1135                                 printIndent(*css); *css << "break;\n";
1136                                 indent -= 4;
1137                                 printIndent(*css); *css << "}\n";
1138                                 break;
1139                         default:
1140                                 DE_ASSERT(0);
1141                                 break;
1142                         }
1143                 }
1144         }
1145
1146         // Simulate execution of the program. If countOnly is true, just return
1147         // the max number of outputs written. If it's false, store out the result
1148         // values to ref
1149         deUint32 simulate(bool countOnly, deUint32 subgroupSize, deUint32 invocationStride, deUint64 *ref)
1150         {
1151                 // State of the subgroup at each level of nesting
1152                 struct SubgroupState
1153                 {
1154                         // Currently executing
1155                         bitset128 activeMask;
1156                         // Have executed a continue instruction in this loop
1157                         bitset128 continueMask;
1158                         // index of the current if test or loop header
1159                         deUint32 header;
1160                         // number of loop iterations performed
1161                         deUint32 tripCount;
1162                         // is this nesting a loop?
1163                         deUint32 isLoop;
1164                         // is this nesting a function call?
1165                         deUint32 isCall;
1166                         // is this nesting a switch?
1167                         deUint32 isSwitch;
1168                 };
1169                 SubgroupState stateStack[10] = {};
1170
1171                 const deUint64 fullSubgroupMask = subgroupSizeToMask(subgroupSize);
1172
1173                 // Per-invocation output location counters
1174                 deUint32 outLoc[128] = {0};
1175
1176                 nesting = 0;
1177                 loopNesting = 0;
1178                 stateStack[nesting].activeMask = ~bitset128(); // initialized to ~0
1179
1180                 deInt32 i = 0;
1181                 while (i < (deInt32)ops.size())
1182                 {
1183                         switch (ops[i].type)
1184                         {
1185                         case OP_BALLOT:
1186
1187                                 // Flag that this ballot is workgroup-nonuniform
1188                                 if (caseDef.isWUCF() && stateStack[nesting].activeMask.any() && !stateStack[nesting].activeMask.all())
1189                                         ops[i].caseValue = 1;
1190
1191                                 if (caseDef.isSUCF())
1192                                 {
1193                                         for (deUint32 id = 0; id < 128; id += subgroupSize)
1194                                         {
1195                                                 deUint64 subgroupMask = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
1196                                                 // Flag that this ballot is subgroup-nonuniform
1197                                                 if (subgroupMask != 0 && subgroupMask != fullSubgroupMask)
1198                                                         ops[i].caseValue = 1;
1199                                         }
1200                                 }
1201
1202                                 for (deUint32 id = 0; id < 128; ++id)
1203                                 {
1204                                         if (stateStack[nesting].activeMask.test(id))
1205                                         {
1206                                                 if (countOnly)
1207                                                 {
1208                                                         outLoc[id]++;
1209                                                 }
1210                                                 else
1211                                                 {
1212                                                         if (ops[i].caseValue)
1213                                                         {
1214                                                                 // Emit a magic value to indicate that we shouldn't validate this ballot
1215                                                                 ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(0x12345678, subgroupSize, id);
1216                                                         }
1217                                                         else
1218                                                                 ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
1219                                                 }
1220                                         }
1221                                 }
1222                                 break;
1223                         case OP_STORE:
1224                                 for (deUint32 id = 0; id < 128; ++id)
1225                                 {
1226                                         if (stateStack[nesting].activeMask.test(id))
1227                                         {
1228                                                 if (countOnly)
1229                                                         outLoc[id]++;
1230                                                 else
1231                                                         ref[(outLoc[id]++)*invocationStride + id] = ops[i].value;
1232                                         }
1233                                 }
1234                                 break;
1235                         case OP_IF_MASK:
1236                                 nesting++;
1237                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64(ops[i].value, subgroupSize);
1238                                 stateStack[nesting].header = i;
1239                                 stateStack[nesting].isLoop = 0;
1240                                 stateStack[nesting].isSwitch = 0;
1241                                 break;
1242                         case OP_ELSE_MASK:
1243                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & ~bitsetFromU64(ops[stateStack[nesting].header].value, subgroupSize);
1244                                 break;
1245                         case OP_IF_LOOPCOUNT:
1246                                 {
1247                                         deUint32 n = nesting;
1248                                         while (!stateStack[n].isLoop)
1249                                                 n--;
1250
1251                                         nesting++;
1252                                         stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64((1ULL << stateStack[n].tripCount), subgroupSize);
1253                                         stateStack[nesting].header = i;
1254                                         stateStack[nesting].isLoop = 0;
1255                                         stateStack[nesting].isSwitch = 0;
1256                                         break;
1257                                 }
1258                         case OP_ELSE_LOOPCOUNT:
1259                                 {
1260                                         deUint32 n = nesting;
1261                                         while (!stateStack[n].isLoop)
1262                                                 n--;
1263
1264                                         stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & ~bitsetFromU64((1ULL << stateStack[n].tripCount), subgroupSize);
1265                                         break;
1266                                 }
1267                         case OP_IF_LOCAL_INVOCATION_INDEX:
1268                                 {
1269                                         // all bits >= N
1270                                         bitset128 mask(0);
1271                                         for (deInt32 j = (deInt32)ops[i].value; j < 128; ++j)
1272                                                 mask.set(j);
1273
1274                                         nesting++;
1275                                         stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & mask;
1276                                         stateStack[nesting].header = i;
1277                                         stateStack[nesting].isLoop = 0;
1278                                         stateStack[nesting].isSwitch = 0;
1279                                         break;
1280                                 }
1281                         case OP_ELSE_LOCAL_INVOCATION_INDEX:
1282                                 {
1283                                         // all bits < N
1284                                         bitset128 mask(0);
1285                                         for (deInt32 j = 0; j < (deInt32)ops[i].value; ++j)
1286                                                 mask.set(j);
1287
1288                                         stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & mask;
1289                                         break;
1290                                 }
1291                         case OP_ENDIF:
1292                                 nesting--;
1293                                 break;
1294                         case OP_BEGIN_FOR_UNIF:
1295                                 // XXX TODO: We don't handle a for loop with zero iterations
1296                                 nesting++;
1297                                 loopNesting++;
1298                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1299                                 stateStack[nesting].header = i;
1300                                 stateStack[nesting].tripCount = 0;
1301                                 stateStack[nesting].isLoop = 1;
1302                                 stateStack[nesting].isSwitch = 0;
1303                                 stateStack[nesting].continueMask = 0;
1304                                 break;
1305                         case OP_END_FOR_UNIF:
1306                                 stateStack[nesting].tripCount++;
1307                                 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1308                                 stateStack[nesting].continueMask = 0;
1309                                 if (stateStack[nesting].tripCount < ops[stateStack[nesting].header].value &&
1310                                         stateStack[nesting].activeMask.any())
1311                                 {
1312                                         i = stateStack[nesting].header+1;
1313                                         continue;
1314                                 }
1315                                 else
1316                                 {
1317                                         loopNesting--;
1318                                         nesting--;
1319                                 }
1320                                 break;
1321                         case OP_BEGIN_DO_WHILE_UNIF:
1322                                 // XXX TODO: We don't handle a for loop with zero iterations
1323                                 nesting++;
1324                                 loopNesting++;
1325                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1326                                 stateStack[nesting].header = i;
1327                                 stateStack[nesting].tripCount = 1;
1328                                 stateStack[nesting].isLoop = 1;
1329                                 stateStack[nesting].isSwitch = 0;
1330                                 stateStack[nesting].continueMask = 0;
1331                                 break;
1332                         case OP_END_DO_WHILE_UNIF:
1333                                 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1334                                 stateStack[nesting].continueMask = 0;
1335                                 if (stateStack[nesting].tripCount < ops[stateStack[nesting].header].value &&
1336                                         stateStack[nesting].activeMask.any())
1337                                 {
1338                                         i = stateStack[nesting].header+1;
1339                                         stateStack[nesting].tripCount++;
1340                                         continue;
1341                                 }
1342                                 else
1343                                 {
1344                                         loopNesting--;
1345                                         nesting--;
1346                                 }
1347                                 break;
1348                         case OP_BEGIN_FOR_VAR:
1349                                 // XXX TODO: We don't handle a for loop with zero iterations
1350                                 nesting++;
1351                                 loopNesting++;
1352                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1353                                 stateStack[nesting].header = i;
1354                                 stateStack[nesting].tripCount = 0;
1355                                 stateStack[nesting].isLoop = 1;
1356                                 stateStack[nesting].isSwitch = 0;
1357                                 stateStack[nesting].continueMask = 0;
1358                                 break;
1359                         case OP_END_FOR_VAR:
1360                                 stateStack[nesting].tripCount++;
1361                                 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1362                                 stateStack[nesting].continueMask = 0;
1363                                 stateStack[nesting].activeMask &= bitsetFromU64(stateStack[nesting].tripCount == subgroupSize ? 0 : ~((1ULL << (stateStack[nesting].tripCount)) - 1), subgroupSize);
1364                                 if (stateStack[nesting].activeMask.any())
1365                                 {
1366                                         i = stateStack[nesting].header+1;
1367                                         continue;
1368                                 }
1369                                 else
1370                                 {
1371                                         loopNesting--;
1372                                         nesting--;
1373                                 }
1374                                 break;
1375                         case OP_BEGIN_FOR_INF:
1376                         case OP_BEGIN_DO_WHILE_INF:
1377                                 nesting++;
1378                                 loopNesting++;
1379                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1380                                 stateStack[nesting].header = i;
1381                                 stateStack[nesting].tripCount = 0;
1382                                 stateStack[nesting].isLoop = 1;
1383                                 stateStack[nesting].isSwitch = 0;
1384                                 stateStack[nesting].continueMask = 0;
1385                                 break;
1386                         case OP_END_FOR_INF:
1387                                 stateStack[nesting].tripCount++;
1388                                 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1389                                 stateStack[nesting].continueMask = 0;
1390                                 if (stateStack[nesting].activeMask.any())
1391                                 {
1392                                         // output expected OP_BALLOT values
1393                                         for (deUint32 id = 0; id < 128; ++id)
1394                                         {
1395                                                 if (stateStack[nesting].activeMask.test(id))
1396                                                 {
1397                                                         if (countOnly)
1398                                                                 outLoc[id]++;
1399                                                         else
1400                                                                 ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
1401                                                 }
1402                                         }
1403
1404                                         i = stateStack[nesting].header+1;
1405                                         continue;
1406                                 }
1407                                 else
1408                                 {
1409                                         loopNesting--;
1410                                         nesting--;
1411                                 }
1412                                 break;
1413                         case OP_END_DO_WHILE_INF:
1414                                 stateStack[nesting].tripCount++;
1415                                 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1416                                 stateStack[nesting].continueMask = 0;
1417                                 if (stateStack[nesting].activeMask.any())
1418                                 {
1419                                         i = stateStack[nesting].header+1;
1420                                         continue;
1421                                 }
1422                                 else
1423                                 {
1424                                         loopNesting--;
1425                                         nesting--;
1426                                 }
1427                                 break;
1428                         case OP_BREAK:
1429                                 {
1430                                         deUint32 n = nesting;
1431                                         bitset128 mask = stateStack[nesting].activeMask;
1432                                         while (true)
1433                                         {
1434                                                 stateStack[n].activeMask &= ~mask;
1435                                                 if (stateStack[n].isLoop || stateStack[n].isSwitch)
1436                                                         break;
1437
1438                                                 n--;
1439                                         }
1440                                 }
1441                                 break;
1442                         case OP_CONTINUE:
1443                                 {
1444                                         deUint32 n = nesting;
1445                                         bitset128 mask = stateStack[nesting].activeMask;
1446                                         while (true)
1447                                         {
1448                                                 stateStack[n].activeMask &= ~mask;
1449                                                 if (stateStack[n].isLoop)
1450                                                 {
1451                                                         stateStack[n].continueMask |= mask;
1452                                                         break;
1453                                                 }
1454                                                 n--;
1455                                         }
1456                                 }
1457                                 break;
1458                         case OP_ELECT:
1459                                 {
1460                                         nesting++;
1461                                         stateStack[nesting].activeMask = bitsetElect(stateStack[nesting-1].activeMask, subgroupSize);
1462                                         stateStack[nesting].header = i;
1463                                         stateStack[nesting].isLoop = 0;
1464                                         stateStack[nesting].isSwitch = 0;
1465                                 }
1466                                 break;
1467                         case OP_RETURN:
1468                                 {
1469                                         bitset128 mask = stateStack[nesting].activeMask;
1470                                         for (deInt32 n = nesting; n >= 0; --n)
1471                                         {
1472                                                 stateStack[n].activeMask &= ~mask;
1473                                                 if (stateStack[n].isCall)
1474                                                         break;
1475                                         }
1476                                 }
1477                                 break;
1478
1479                         case OP_CALL_BEGIN:
1480                                 nesting++;
1481                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1482                                 stateStack[nesting].isLoop = 0;
1483                                 stateStack[nesting].isSwitch = 0;
1484                                 stateStack[nesting].isCall = 1;
1485                                 break;
1486                         case OP_CALL_END:
1487                                 stateStack[nesting].isCall = 0;
1488                                 nesting--;
1489                                 break;
1490                         case OP_NOISE:
1491                                 break;
1492
1493                         case OP_SWITCH_UNIF_BEGIN:
1494                         case OP_SWITCH_VAR_BEGIN:
1495                         case OP_SWITCH_LOOP_COUNT_BEGIN:
1496                                 nesting++;
1497                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1498                                 stateStack[nesting].header = i;
1499                                 stateStack[nesting].isLoop = 0;
1500                                 stateStack[nesting].isSwitch = 1;
1501                                 break;
1502                         case OP_SWITCH_END:
1503                                 nesting--;
1504                                 break;
1505                         case OP_CASE_MASK_BEGIN:
1506                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64(ops[i].value, subgroupSize);
1507                                 break;
1508                         case OP_CASE_LOOP_COUNT_BEGIN:
1509                                 {
1510                                         deUint32 n = nesting;
1511                                         deUint32 l = loopNesting;
1512
1513                                         while (true)
1514                                         {
1515                                                 if (stateStack[n].isLoop)
1516                                                 {
1517                                                         l--;
1518                                                         if (l == ops[stateStack[nesting].header].value)
1519                                                                 break;
1520                                                 }
1521                                                 n--;
1522                                         }
1523
1524                                         if ((1ULL << stateStack[n].tripCount) & ops[i].value)
1525                                                 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1526                                         else
1527                                                 stateStack[nesting].activeMask = 0;
1528                                         break;
1529                                 }
1530                         case OP_CASE_END:
1531                                 break;
1532
1533                         default:
1534                                 DE_ASSERT(0);
1535                                 break;
1536                         }
1537                         i++;
1538                 }
1539                 deUint32 maxLoc = 0;
1540                 for (deUint32 id = 0; id < ARRAYSIZE(outLoc); ++id)
1541                         maxLoc = de::max(maxLoc, outLoc[id]);
1542
1543                 return maxLoc;
1544         }
1545
1546         bool hasUCF() const
1547         {
1548                 for (deInt32 i = 0; i < (deInt32)ops.size(); ++i)
1549                 {
1550                         if (ops[i].type == OP_BALLOT && ops[i].caseValue == 0)
1551                                 return true;
1552                 }
1553                 return false;
1554         }
1555 };
1556
1557 void ReconvergenceTestCase::initPrograms (SourceCollections& programCollection) const
1558 {
1559         RandomProgram program(m_data);
1560         program.generateRandomProgram();
1561
1562         std::stringstream css;
1563         css << "#version 450 core\n";
1564         css << "#extension GL_KHR_shader_subgroup_ballot : enable\n";
1565         css << "#extension GL_KHR_shader_subgroup_vote : enable\n";
1566         css << "#extension GL_NV_shader_subgroup_partitioned : enable\n";
1567         css << "#extension GL_EXT_subgroup_uniform_control_flow : enable\n";
1568         css << "layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n";
1569         css << "layout(set=0, binding=0) coherent buffer InputA { uint a[]; } inputA;\n";
1570         css << "layout(set=0, binding=1) coherent buffer OutputB { uvec2 b[]; } outputB;\n";
1571         css << "layout(set=0, binding=2) coherent buffer OutputC { uint loc[]; } outputC;\n";
1572         css << "layout(push_constant) uniform PC {\n"
1573                         "   // set to the real stride when writing out ballots, or zero when just counting\n"
1574                         "   int invocationStride;\n"
1575                         "};\n";
1576         css << "int outLoc = 0;\n";
1577
1578         css << "bool testBit(uvec2 mask, uint bit) { return (bit < 32) ? ((mask.x >> bit) & 1) != 0 : ((mask.y >> (bit-32)) & 1) != 0; }\n";
1579
1580         css << "uint elect() { return int(subgroupElect()) + 1; }\n";
1581
1582         std::stringstream functions, main;
1583         program.genCode(functions, main);
1584
1585         css << functions.str() << "\n\n";
1586
1587         css <<
1588                 "void main()\n"
1589                 << (m_data.isSUCF() ? "[[subgroup_uniform_control_flow]]\n" : "") <<
1590                 "{\n";
1591
1592         css << main.str() << "\n\n";
1593
1594         css << "}\n";
1595
1596         const vk::ShaderBuildOptions    buildOptions    (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
1597
1598         programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
1599 }
1600
1601 TestInstance* ReconvergenceTestCase::createInstance (Context& context) const
1602 {
1603         return new ReconvergenceTestInstance(context, m_data);
1604 }
1605
1606 tcu::TestStatus ReconvergenceTestInstance::iterate (void)
1607 {
1608         const DeviceInterface&  vk                                              = m_context.getDeviceInterface();
1609         const VkDevice                  device                                  = m_context.getDevice();
1610         Allocator&                              allocator                               = m_context.getDefaultAllocator();
1611         tcu::TestLog&                   log                                             = m_context.getTestContext().getLog();
1612
1613         deRandom rnd;
1614         deRandom_init(&rnd, m_data.seed);
1615
1616         vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
1617         deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
1618         subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
1619
1620         vk::VkPhysicalDeviceProperties2 properties2;
1621         deMemset(&properties2, 0, sizeof(properties2));
1622         properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
1623         properties2.pNext = &subgroupProperties;
1624
1625         m_context.getInstanceInterface().getPhysicalDeviceProperties2(m_context.getPhysicalDevice(), &properties2);
1626
1627         const deUint32 subgroupSize = subgroupProperties.subgroupSize;
1628         const deUint32 invocationStride = 128;
1629
1630         if (subgroupSize > 64)
1631                 TCU_THROW(TestError, "Subgroup size greater than 64 not handled.");
1632
1633         RandomProgram program(m_data);
1634         program.generateRandomProgram();
1635
1636         deUint32 maxLoc = program.simulate(true, subgroupSize, invocationStride, DE_NULL);
1637
1638         // maxLoc is per-invocation. Add one (to make sure no additional writes are done) and multiply by
1639         // the number of invocations
1640         maxLoc++;
1641         maxLoc *= invocationStride;
1642
1643         // buffer[0] is an input filled with a[i] == i
1644         // buffer[1] is the output
1645         // buffer[2] is the location counts
1646         de::MovePtr<BufferWithMemory> buffers[3];
1647         vk::VkDescriptorBufferInfo bufferDescriptors[3];
1648
1649         VkDeviceSize sizes[3] =
1650         {
1651                 128 * sizeof(deUint32),
1652                 maxLoc * sizeof(deUint64),
1653                 invocationStride * sizeof(deUint32),
1654         };
1655
1656         for (deUint32 i = 0; i < 3; ++i)
1657         {
1658                 try
1659                 {
1660                         buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1661                                 vk, device, allocator, makeBufferCreateInfo(sizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT),
1662                                 MemoryRequirement::HostVisible | MemoryRequirement::Cached));
1663                 }
1664                 catch(tcu::ResourceError&)
1665                 {
1666                         // Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
1667                         return tcu::TestStatus(QP_TEST_RESULT_QUALITY_WARNING, "Failed device memory allocation " + de::toString(sizes[i]) + " bytes");
1668                 }
1669                 bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, sizes[i]);
1670         }
1671
1672         deUint32 *ptrs[3];
1673         for (deUint32 i = 0; i < 3; ++i)
1674         {
1675                 ptrs[i] = (deUint32 *)buffers[i]->getAllocation().getHostPtr();
1676         }
1677         for (deUint32 i = 0; i < sizes[0] / sizeof(deUint32); ++i)
1678         {
1679                 ptrs[0][i] = i;
1680         }
1681         deMemset(ptrs[1], 0, (size_t)sizes[1]);
1682         deMemset(ptrs[2], 0, (size_t)sizes[2]);
1683
1684         vk::DescriptorSetLayoutBuilder layoutBuilder;
1685
1686         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1687         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1688         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1689
1690         vk::Unique<vk::VkDescriptorSetLayout>   descriptorSetLayout(layoutBuilder.build(vk, device));
1691
1692         vk::Unique<vk::VkDescriptorPool>                descriptorPool(vk::DescriptorPoolBuilder()
1693                 .addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 3u)
1694                 .build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
1695         vk::Unique<vk::VkDescriptorSet>                 descriptorSet           (makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));
1696
1697         const deUint32 specData[1] =
1698         {
1699                 invocationStride,
1700         };
1701         const vk::VkSpecializationMapEntry entries[1] =
1702         {
1703                 {0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
1704         };
1705         const vk::VkSpecializationInfo specInfo =
1706         {
1707                 1,                                              // mapEntryCount
1708                 entries,                                // pMapEntries
1709                 sizeof(specData),               // dataSize
1710                 specData                                // pData
1711         };
1712
1713         const VkPushConstantRange                               pushConstantRange                               =
1714         {
1715                 allShaderStages,                                                                                        // VkShaderStageFlags                                   stageFlags;
1716                 0u,                                                                                                                     // deUint32                                                             offset;
1717                 sizeof(deInt32)                                                                                         // deUint32                                                             size;
1718         };
1719
1720         const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo =
1721         {
1722                 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,                          // sType
1723                 DE_NULL,                                                                                                        // pNext
1724                 (VkPipelineLayoutCreateFlags)0,
1725                 1,                                                                                                                      // setLayoutCount
1726                 &descriptorSetLayout.get(),                                                                     // pSetLayouts
1727                 1u,                                                                                                                     // pushConstantRangeCount
1728                 &pushConstantRange,                                                                                     // pPushConstantRanges
1729         };
1730
1731         Move<VkPipelineLayout> pipelineLayout = createPipelineLayout(vk, device, &pipelineLayoutCreateInfo, NULL);
1732
1733         VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
1734
1735         flushAlloc(vk, device, buffers[0]->getAllocation());
1736         flushAlloc(vk, device, buffers[1]->getAllocation());
1737         flushAlloc(vk, device, buffers[2]->getAllocation());
1738
1739         const VkBool32 computeFullSubgroups = subgroupProperties.subgroupSize <= 64 &&
1740                                                                                   m_context.getSubgroupSizeControlFeaturesEXT().computeFullSubgroups;
1741
1742         const VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT subgroupSizeCreateInfo =
1743         {
1744                 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT, // VkStructureType              sType;
1745                 DE_NULL,                                                                                                                                                // void*                                pNext;
1746                 subgroupProperties.subgroupSize                                                                                                 // uint32_t                             requiredSubgroupSize;
1747         };
1748
1749         const void *shaderPNext = computeFullSubgroups ? &subgroupSizeCreateInfo : DE_NULL;
1750         VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags =
1751                 (VkPipelineShaderStageCreateFlags)(computeFullSubgroups ? VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT : 0);
1752
1753         const Unique<VkShaderModule>                    shader                                          (createShaderModule(vk, device, m_context.getBinaryCollection().get("test"), 0));
1754         const VkPipelineShaderStageCreateInfo   shaderCreateInfo =
1755         {
1756                 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1757                 shaderPNext,
1758                 pipelineShaderStageCreateFlags,
1759                 VK_SHADER_STAGE_COMPUTE_BIT,                                                            // stage
1760                 *shader,                                                                                                        // shader
1761                 "main",
1762                 &specInfo,                                                                                                      // pSpecializationInfo
1763         };
1764
1765         const VkComputePipelineCreateInfo               pipelineCreateInfo =
1766         {
1767                 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1768                 DE_NULL,
1769                 0u,                                                                                                                     // flags
1770                 shaderCreateInfo,                                                                                       // cs
1771                 *pipelineLayout,                                                                                        // layout
1772                 (vk::VkPipeline)0,                                                                                      // basePipelineHandle
1773                 0u,                                                                                                                     // basePipelineIndex
1774         };
1775         Move<VkPipeline> pipeline = createComputePipeline(vk, device, DE_NULL, &pipelineCreateInfo, NULL);
1776
1777         const VkQueue                                   queue                                   = m_context.getUniversalQueue();
1778         Move<VkCommandPool>                             cmdPool                                 = createCommandPool(vk, device, vk::VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT, m_context.getUniversalQueueFamilyIndex());
1779         Move<VkCommandBuffer>                   cmdBuffer                               = allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1780
1781
1782         vk::DescriptorSetUpdateBuilder setUpdateBuilder;
1783         setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
1784                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
1785         setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1786                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1787         setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
1788                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
1789         setUpdateBuilder.update(vk, device);
1790
1791         // compute "maxLoc", the maximum number of locations written
1792         beginCommandBuffer(vk, *cmdBuffer, 0u);
1793
1794         vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
1795         vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);
1796
1797         deInt32 pcinvocationStride = 0;
1798         vk.cmdPushConstants(*cmdBuffer, *pipelineLayout, allShaderStages, 0, sizeof(pcinvocationStride), &pcinvocationStride);
1799
1800         vk.cmdDispatch(*cmdBuffer, 1, 1, 1);
1801
1802         endCommandBuffer(vk, *cmdBuffer);
1803
1804         submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1805
1806         invalidateAlloc(vk, device, buffers[1]->getAllocation());
1807         invalidateAlloc(vk, device, buffers[2]->getAllocation());
1808
1809         // Clear any writes to buffer[1] during the counting pass
1810         deMemset(ptrs[1], 0, invocationStride * sizeof(deUint64));
1811
1812         // Take the max over all invocations. Add one (to make sure no additional writes are done) and multiply by
1813         // the number of invocations
1814         deUint32 newMaxLoc = 0;
1815         for (deUint32 id = 0; id < invocationStride; ++id)
1816                 newMaxLoc = de::max(newMaxLoc, ptrs[2][id]);
1817         newMaxLoc++;
1818         newMaxLoc *= invocationStride;
1819
1820         // If we need more space, reallocate buffers[1]
1821         if (newMaxLoc > maxLoc)
1822         {
1823                 maxLoc = newMaxLoc;
1824                 sizes[1] = maxLoc * sizeof(deUint64);
1825
1826                 try
1827                 {
1828                         buffers[1] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1829                                 vk, device, allocator, makeBufferCreateInfo(sizes[1], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT),
1830                                 MemoryRequirement::HostVisible | MemoryRequirement::Cached));
1831                 }
1832                 catch(tcu::ResourceError&)
1833                 {
1834                         // Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
1835                         return tcu::TestStatus(QP_TEST_RESULT_QUALITY_WARNING, "Failed device memory allocation " + de::toString(sizes[1]) + " bytes");
1836                 }
1837                 bufferDescriptors[1] = makeDescriptorBufferInfo(**buffers[1], 0, sizes[1]);
1838                 ptrs[1] = (deUint32 *)buffers[1]->getAllocation().getHostPtr();
1839                 deMemset(ptrs[1], 0, (size_t)sizes[1]);
1840
1841                 vk::DescriptorSetUpdateBuilder setUpdateBuilder2;
1842                 setUpdateBuilder2.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1843                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1844                 setUpdateBuilder2.update(vk, device);
1845         }
1846
1847         flushAlloc(vk, device, buffers[1]->getAllocation());
1848
1849         // run the actual shader
1850         beginCommandBuffer(vk, *cmdBuffer, 0u);
1851
1852         vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
1853         vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);
1854
1855         pcinvocationStride = invocationStride;
1856         vk.cmdPushConstants(*cmdBuffer, *pipelineLayout, allShaderStages, 0, sizeof(pcinvocationStride), &pcinvocationStride);
1857
1858         vk.cmdDispatch(*cmdBuffer, 1, 1, 1);
1859
1860         endCommandBuffer(vk, *cmdBuffer);
1861
1862         submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1863
1864         invalidateAlloc(vk, device, buffers[1]->getAllocation());
1865
1866         qpTestResult res = QP_TEST_RESULT_PASS;
1867
1868         // Simulate execution on the CPU, and compare against the GPU result
1869         deUint64 *ref = new deUint64 [maxLoc];
1870         deMemset(ref, 0, maxLoc*sizeof(deUint64));
1871         program.simulate(false, subgroupSize, invocationStride, ref);
1872
1873         const deUint64 *result = (const deUint64 *)ptrs[1];
1874
1875         if (m_data.testType == TT_MAXIMAL)
1876         {
1877                 // With maximal reconvergence, we should expect the output to exactly match
1878                 // the reference.
1879                 for (deUint32 i = 0; i < maxLoc; ++i)
1880                 {
1881                         if (result[i] != ref[i])
1882                         {
1883                                 log << tcu::TestLog::Message << "first mismatch at " << i << tcu::TestLog::EndMessage;
1884                                 res = QP_TEST_RESULT_FAIL;
1885                                 break;
1886                         }
1887                 }
1888
1889                 if (res != QP_TEST_RESULT_PASS)
1890                 {
1891                         for (deUint32 i = 0; i < maxLoc; ++i)
1892                         {
1893                                 // This log can be large and slow, ifdef it out by default
1894 #if 0
1895                                 log << tcu::TestLog::Message << "result " << i << "(" << (i/invocationStride) << ", " << (i%invocationStride) << "): " << tcu::toHex(result[i]) << " ref " << tcu::toHex(ref[i]) << (result[i] != ref[i] ? " different" : "") << tcu::TestLog::EndMessage;
1896 #endif
1897                         }
1898                 }
1899         }
1900         else
1901         {
1902                 deUint64 fullMask = subgroupSizeToMask(subgroupSize);
1903                 // For subgroup_uniform_control_flow, we expect any fully converged outputs in the reference
1904                 // to have a corresponding fully converged output in the result. So walk through each lane's
1905                 // results, and for each reference value of fullMask, find a corresponding result value of
1906                 // fullMask where the previous value (OP_STORE) matches. That means these came from the same
1907                 // source location.
1908                 vector<deUint32> firstFail(invocationStride, 0);
1909                 for (deUint32 lane = 0; lane < invocationStride; ++lane)
1910                 {
1911                         deUint32 resLoc = lane + invocationStride, refLoc = lane + invocationStride;
1912                         while (refLoc < maxLoc)
1913                         {
1914                                 while (refLoc < maxLoc && ref[refLoc] != fullMask)
1915                                         refLoc += invocationStride;
1916                                 if (refLoc >= maxLoc)
1917                                         break;
1918
1919                                 // For TT_SUCF_ELECT, when the reference result has a full mask, we expect lane 0 to be elected
1920                                 // (a value of 2) and all other lanes to be not elected (a value of 1). For TT_SUCF_BALLOT, we
1921                                 // expect a full mask. Search until we find the expected result with a matching store value in
1922                                 // the previous result.
1923                                 deUint64 expectedResult = m_data.isElect() ? ((lane % subgroupSize) == 0 ? 2 : 1)
1924                                                                                                                          : fullMask;
1925
1926                                 while (resLoc < maxLoc && !(result[resLoc] == expectedResult && result[resLoc-invocationStride] == ref[refLoc-invocationStride]))
1927                                         resLoc += invocationStride;
1928
1929                                 // If we didn't find this output in the result, flag it as an error.
1930                                 if (resLoc >= maxLoc)
1931                                 {
1932                                         firstFail[lane] = refLoc;
1933                                         log << tcu::TestLog::Message << "lane " << lane << " first mismatch at " << firstFail[lane] << tcu::TestLog::EndMessage;
1934                                         res = QP_TEST_RESULT_FAIL;
1935                                         break;
1936                                 }
1937                                 refLoc += invocationStride;
1938                                 resLoc += invocationStride;
1939                         }
1940                 }
1941
1942                 if (res != QP_TEST_RESULT_PASS)
1943                 {
1944                         for (deUint32 i = 0; i < maxLoc; ++i)
1945                         {
1946                                 // This log can be large and slow, ifdef it out by default
1947 #if 0
1948                                 log << tcu::TestLog::Message << "result " << i << "(" << (i/invocationStride) << ", " << (i%invocationStride) << "): " << tcu::toHex(result[i]) << " ref " << tcu::toHex(ref[i]) << (i == firstFail[i%invocationStride] ? " first fail" : "") << tcu::TestLog::EndMessage;
1949 #endif
1950                         }
1951                 }
1952         }
1953
1954         delete []ref;
1955
1956         return tcu::TestStatus(res, qpGetTestResultName(res));
1957 }
1958
1959 }       // anonymous
1960
1961 tcu::TestCaseGroup*     createTests (tcu::TestContext& testCtx, bool createExperimental)
1962 {
1963         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
1964                         testCtx, "reconvergence", "reconvergence tests"));
1965
1966         typedef struct
1967         {
1968                 deUint32                                value;
1969                 const char*                             name;
1970                 const char*                             description;
1971         } TestGroupCase;
1972
1973         TestGroupCase ttCases[] =
1974         {
1975                 { TT_SUCF_ELECT,                                "subgroup_uniform_control_flow_elect",  "subgroup_uniform_control_flow_elect"           },
1976                 { TT_SUCF_BALLOT,                               "subgroup_uniform_control_flow_ballot", "subgroup_uniform_control_flow_ballot"          },
1977                 { TT_WUCF_ELECT,                                "workgroup_uniform_control_flow_elect", "workgroup_uniform_control_flow_elect"          },
1978                 { TT_WUCF_BALLOT,                               "workgroup_uniform_control_flow_ballot","workgroup_uniform_control_flow_ballot"         },
1979                 { TT_MAXIMAL,                                   "maximal",                                                              "maximal"                                                                       },
1980         };
1981
1982         for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
1983         {
1984                 de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, ttCases[ttNdx].name, ttCases[ttNdx].description));
1985                 de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(testCtx, "compute", ""));
1986
1987                 for (deUint32 nNdx = 2; nNdx <= 6; nNdx++)
1988                 {
1989                         de::MovePtr<tcu::TestCaseGroup> nestGroup(new tcu::TestCaseGroup(testCtx, ("nesting" + de::toString(nNdx)).c_str(), ""));
1990
1991                         deUint32 seed = 0;
1992
1993                         for (int sNdx = 0; sNdx < 8; sNdx++)
1994                         {
1995                                 de::MovePtr<tcu::TestCaseGroup> seedGroup(new tcu::TestCaseGroup(testCtx, de::toString(sNdx).c_str(), ""));
1996
1997                                 deUint32 numTests = 0;
1998                                 switch (nNdx)
1999                                 {
2000                                 default:
2001                                         DE_ASSERT(0);
2002                                         // fallthrough
2003                                 case 2:
2004                                 case 3:
2005                                 case 4:
2006                                         numTests = 250;
2007                                         break;
2008                                 case 5:
2009                                         numTests = 100;
2010                                         break;
2011                                 case 6:
2012                                         numTests = 50;
2013                                         break;
2014                                 }
2015
2016                                 if (ttCases[ttNdx].value != TT_MAXIMAL)
2017                                 {
2018                                         if (nNdx >= 5)
2019                                                 continue;
2020                                 }
2021
2022                                 for (deUint32 ndx = 0; ndx < numTests; ndx++)
2023                                 {
2024                                         CaseDef c =
2025                                         {
2026                                                 (TestType)ttCases[ttNdx].value,         // TestType testType;
2027                                                 nNdx,                                                           // deUint32 maxNesting;
2028                                                 seed,                                                           // deUint32 seed;
2029                                         };
2030                                         seed++;
2031
2032                                         bool isExperimentalTest = !c.isUCF() || (ndx >= numTests / 5);
2033
2034                                         if (createExperimental == isExperimentalTest)
2035                                                 seedGroup->addChild(new ReconvergenceTestCase(testCtx, de::toString(ndx).c_str(), "", c));
2036                                 }
2037                                 if (!seedGroup->empty())
2038                                         nestGroup->addChild(seedGroup.release());
2039                         }
2040                         if (!nestGroup->empty())
2041                                 computeGroup->addChild(nestGroup.release());
2042                 }
2043                 if (!computeGroup->empty())
2044                 {
2045                         ttGroup->addChild(computeGroup.release());
2046                         group->addChild(ttGroup.release());
2047                 }
2048         }
2049         return group.release();
2050 }
2051
2052 }       // Reconvergence
2053 }       // vkt