1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
5 * Copyright (c) 2019 The Khronos Group Inc.
6 * Copyright (c) 2018-2020 NVIDIA Corporation
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
12 * http://www.apache.org/licenses/LICENSE-2.0
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
22 * \brief Vulkan Reconvergence tests
23 *//*--------------------------------------------------------------------*/
25 #include "vktReconvergenceTests.hpp"
27 #include "vkBufferWithMemory.hpp"
28 #include "vkImageWithMemory.hpp"
29 #include "vkQueryUtil.hpp"
30 #include "vkBuilderUtil.hpp"
31 #include "vkCmdUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkObjUtil.hpp"
35 #include "vktTestGroupUtil.hpp"
36 #include "vktTestCase.hpp"
39 #include "deFloat16.h"
42 #include "deSharedPtr.hpp"
45 #include "tcuTestCase.hpp"
46 #include "tcuTestLog.hpp"
56 namespace Reconvergence
// Element count of a C-style array (file-local helper macro).
63 #define ARRAYSIZE(x) (sizeof(x) / sizeof(x[0]))
// These tests only exercise the compute stage.
65 const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;
// Flavors of reconvergence guarantee being tested.
68 TT_SUCF_ELECT, // subgroup_uniform_control_flow using elect (subgroup_basic)
69 TT_SUCF_BALLOT, // subgroup_uniform_control_flow using ballot (subgroup_ballot)
70 TT_WUCF_ELECT, // workgroup uniform control flow using elect (subgroup_basic)
71 TT_WUCF_BALLOT, // workgroup uniform control flow using ballot (subgroup_ballot)
72 TT_MAXIMAL, // maximal reconvergence
// Convenience predicates over testType, used by checkSupport and the generators.
81 bool isWUCF() const { return testType == TT_WUCF_ELECT || testType == TT_WUCF_BALLOT; }
82 bool isSUCF() const { return testType == TT_SUCF_ELECT || testType == TT_SUCF_BALLOT; }
83 bool isUCF() const { return isWUCF() || isSUCF(); }
84 bool isElect() const { return testType == TT_WUCF_ELECT || testType == TT_SUCF_ELECT; }
87 deUint64 subgroupSizeToMask(deUint32 subgroupSize)
89 if (subgroupSize == 64)
92 return (1ULL << subgroupSize) - 1;
95 typedef std::bitset<128> bitset128;
97 // Take a 64-bit integer, mask it to the subgroup size, and then
98 // replicate it for each subgroup
99 bitset128 bitsetFromU64(deUint64 mask, deUint32 subgroupSize)
101 mask &= subgroupSizeToMask(subgroupSize);
102 bitset128 result(mask);
103 for (deUint32 i = 0; i < 128 / subgroupSize - 1; ++i)
105 result = (result << subgroupSize) | bitset128(mask);
110 // Pick out the mask for the subgroup that invocationID is a member of
111 deUint64 bitsetToU64(const bitset128 &bitset, deUint32 subgroupSize, deUint32 invocationID)
113 bitset128 copy(bitset);
114 copy >>= (invocationID / subgroupSize) * subgroupSize;
115 copy &= bitset128(subgroupSizeToMask(subgroupSize));
116 deUint64 mask = copy.to_ullong();
117 mask &= subgroupSizeToMask(subgroupSize);
// Per-test-run instance; iterate() performs the actual dispatch and validation.
121 class ReconvergenceTestInstance : public TestInstance
124 ReconvergenceTestInstance (Context& context, const CaseDef& data);
125 ~ReconvergenceTestInstance (void);
126 tcu::TestStatus iterate (void);
// Constructor simply forwards the context and stores the case parameters.
131 ReconvergenceTestInstance::ReconvergenceTestInstance (Context& context, const CaseDef& data)
132 : vkt::TestInstance (context)
137 ReconvergenceTestInstance::~ReconvergenceTestInstance (void)
// Test-case node: owns the CaseDef, builds shaders, and gates on feature support.
141 class ReconvergenceTestCase : public TestCase
144 ReconvergenceTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
145 ~ReconvergenceTestCase (void);
146 virtual void initPrograms (SourceCollections& programCollection) const;
147 virtual TestInstance* createInstance (Context& context) const;
148 virtual void checkSupport (Context& context) const;
154 ReconvergenceTestCase::ReconvergenceTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
155 : vkt::TestCase (context, name, desc)
160 ReconvergenceTestCase::~ReconvergenceTestCase (void)
// Throw NotSupportedError unless the device supports everything this case needs:
// Vulkan 1.1, the required subgroup operation class (basic for elect tests,
// ballot otherwise), subgroup ops in compute, and — for UCF cases — the
// shaderSubgroupUniformControlFlow feature.
164 void ReconvergenceTestCase::checkSupport(Context& context) const
166 if (!context.contextSupports(vk::ApiVersion(1, 1, 0)))
167 TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
169 vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
170 deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
171 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
// Query subgroup properties via the properties2 pNext chain.
173 vk::VkPhysicalDeviceProperties2 properties2;
174 deMemset(&properties2, 0, sizeof(properties2));
175 properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
176 properties2.pNext = &subgroupProperties;
178 context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties2);
// Elect-based tests need subgroup basic ops; ballot-based tests need ballot ops.
180 if (m_data.isElect() && !(subgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BASIC_BIT))
181 TCU_THROW(NotSupportedError, "VK_SUBGROUP_FEATURE_BASIC_BIT not supported");
183 if (!m_data.isElect() && !(subgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BALLOT_BIT))
184 TCU_THROW(NotSupportedError, "VK_SUBGROUP_FEATURE_BALLOT_BIT not supported");
186 if (!(context.getSubgroupProperties().supportedStages & VK_SHADER_STAGE_COMPUTE_BIT))
187 TCU_THROW(NotSupportedError, "compute stage does not support subgroup operations");
189 // Both subgroup- AND workgroup-uniform tests are enabled by shaderSubgroupUniformControlFlow.
190 if (m_data.isUCF() && !context.getShaderSubgroupUniformControlFlowFeatures().shaderSubgroupUniformControlFlow)
191 TCU_THROW(NotSupportedError, "shaderSubgroupUniformControlFlow not supported");
193 // XXX TODO: Check for maximal reconvergence support
194 // if (m_data.testType == TT_MAXIMAL ...)
// Opcodes of the tiny randomly-generated "program" that is both emitted as
// GLSL (genCode) and interpreted on the CPU (simulate) to produce reference
// ballot values.
199 // store subgroupBallot().
200 // For OP_BALLOT, OP::caseValue is initialized to zero, and then
// (typo fixed) set to 1 by simulate if the ballot is not workgroup- (or subgroup-) uniform.
202 // Only workgroup-uniform ballots are validated for correctness in
206 // store literal constant
209 // if ((1ULL << gl_SubgroupInvocationID) & mask).
210 // Special case if mask = ~0ULL, converted into "if (inputA.a[idx] == idx)"
215 // if (gl_SubgroupInvocationID == loopIdxN) (where N is most nested loop counter)
219 // if (gl_LocalInvocationIndex >= inputA.a[N]) (where N is most nested loop counter)
220 OP_IF_LOCAL_INVOCATION_INDEX,
221 OP_ELSE_LOCAL_INVOCATION_INDEX,
227 // if (subgroupElect())
230 // Loop with uniform number of iterations (read from a buffer)
234 // for (int loopIdxN = 0; loopIdxN < gl_SubgroupInvocationID + 1; ++loopIdxN)
238 // for (int loopIdxN = 0;; ++loopIdxN, OP_BALLOT)
239 // Always has an "if (subgroupElect()) break;" inside.
240 // Does the equivalent of OP_BALLOT in the continue construct
244 // do { loopIdxN++; ... } while (loopIdxN < uniformValue);
245 OP_BEGIN_DO_WHILE_UNIF,
246 OP_END_DO_WHILE_UNIF,
248 // do { ... } while (true);
249 // Always has an "if (subgroupElect()) break;" inside
250 OP_BEGIN_DO_WHILE_INF,
256 // function call (code bracketed by these is extracted into a separate function)
260 // switch statement on uniform value
261 OP_SWITCH_UNIF_BEGIN,
262 // switch statement on gl_SubgroupInvocationID & 3 value
264 // switch statement on loopIdx value
265 OP_SWITCH_LOOP_COUNT_BEGIN,
267 // case statement with a (invocation mask, case mask) pair
269 // case statement used for loop counter switches, with a value and a mask of loop iterations
270 OP_CASE_LOOP_COUNT_BEGIN,
272 // end of switch/case statement
276 // Extra code with no functional effect. Currently includes:
277 // - value 0: while (!subgroupElect()) {}
278 // - value 1: if (condition_that_is_false) { infinite loop }
284 // Different if test conditions
288 IF_LOCAL_INVOCATION_INDEX,
// One instruction of the random program: opcode plus up to two operands.
294 OP(OPType _type, deUint64 _value, deUint32 _caseValue = 0)
295 : type(_type), value(_value), caseValue(_caseValue)
298 // The type of operation and an optional value.
299 // The value could be a mask for an if test, the index of the loop
300 // header for an end of loop, or the constant value for a store instruction
306 static int findLSB (deUint64 value)
308 for (int i = 0; i < 64; i++)
310 if (value & (1ULL<<i))
316 // For each subgroup, pick out the elected invocationID, and accumulate
317 // a bitset of all of them
318 static bitset128 bitsetElect (const bitset128& value, deInt32 subgroupSize)
320 bitset128 ret; // zero initialized
322 for (deInt32 i = 0; i < 128; i += subgroupSize)
324 deUint64 mask = bitsetToU64(value, subgroupSize, i);
325 int lsb = findLSB(mask);
326 ret |= bitset128(lsb == -1 ? 0 : (1ULL << lsb)) << i;
// Seeded random-program generator state. The seed comes from the CaseDef, so
// each test case generates a reproducible program.
334 RandomProgram(const CaseDef &c)
335 : caseDef(c), numMasks(5), nesting(0), maxNesting(c.maxNesting), loopNesting(0), loopNestingThisFunction(0), callNesting(0), minCount(30), indent(0), isLoopInf(100, false), doneInfLoopBreak(100, false), storeBase(0x10000)
337 deRandom_init(&rnd, caseDef.seed);
// Pre-generate the pool of random 64-bit masks used by OP_IF_MASK etc.
338 for (int i = 0; i < numMasks; ++i)
339 masks.push_back(deRandom_getUint64(&rnd));
342 const CaseDef caseDef;
345 vector<deUint64> masks;
// Loop nesting depth within the current function only (resets across OP_CALL).
350 deInt32 loopNestingThisFunction;
// Indexed by loopNesting: whether that loop level is an infinite loop, and
// whether its subgroupElect+break has already been emitted.
354 vector<bool> isLoopInf;
355 vector<bool> doneInfLoopBreak;
356 // Offset the value we use for OP_STORE, to avoid colliding with fully converged
357 // active masks with small subgroup sizes (e.g. with subgroupSize == 4, the SUCF
358 // tests need to know that 0xF is really an active mask).
// Emit a random if/else construct of the requested condition type, recursing
// into pickOP for the then (and sometimes else) bodies.
361 void genIf(IFType ifType)
363 deUint32 maskIdx = deRandom_getUint32(&rnd) % numMasks;
364 deUint64 mask = masks[maskIdx];
// IF_UNIFORM uses the ~0ULL special case (always-true uniform test).
365 if (ifType == IF_UNIFORM)
368 deUint32 localIndexCmp = deRandom_getUint32(&rnd) % 128;
369 if (ifType == IF_LOCAL_INVOCATION_INDEX)
370 ops.push_back({OP_IF_LOCAL_INVOCATION_INDEX, localIndexCmp});
371 else if (ifType == IF_LOOPCOUNT)
372 ops.push_back({OP_IF_LOOPCOUNT, 0});
374 ops.push_back({OP_IF_MASK, mask});
// Remember the extent of the then-block so it can optionally be cloned below.
378 size_t thenBegin = ops.size();
380 size_t thenEnd = ops.size();
382 deUint32 randElse = (deRandom_getUint32(&rnd) % 100);
385 if (ifType == IF_LOCAL_INVOCATION_INDEX)
386 ops.push_back({OP_ELSE_LOCAL_INVOCATION_INDEX, localIndexCmp});
387 else if (ifType == IF_LOOPCOUNT)
388 ops.push_back({OP_ELSE_LOOPCOUNT, 0});
390 ops.push_back({OP_ELSE_MASK, 0});
394 // Sometimes make the else block identical to the then block
395 for (size_t i = thenBegin; i < thenEnd; ++i)
396 ops.push_back(ops[i]);
401 ops.push_back({OP_ENDIF, 0});
// Uniform-trip-count for loop: 1..5 iterations; the END op records the index
// of its loop-header op so simulate can jump back.
407 deUint32 iterCount = (deRandom_getUint32(&rnd) % 5) + 1;
408 ops.push_back({OP_BEGIN_FOR_UNIF, iterCount});
409 deUint32 loopheader = (deUint32)ops.size()-1;
412 loopNestingThisFunction++;
414 ops.push_back({OP_END_FOR_UNIF, loopheader});
415 loopNestingThisFunction--;
// Uniform-trip-count do/while loop: same shape as genForUnif but with
// do/while semantics (body always runs at least once).
420 void genDoWhileUnif()
422 deUint32 iterCount = (deRandom_getUint32(&rnd) % 5) + 1;
423 ops.push_back({OP_BEGIN_DO_WHILE_UNIF, iterCount});
424 deUint32 loopheader = (deUint32)ops.size()-1;
427 loopNestingThisFunction++;
429 ops.push_back({OP_END_DO_WHILE_UNIF, loopheader});
430 loopNestingThisFunction--;
// Loop whose trip count varies per invocation (driven by gl_SubgroupInvocationID).
437 ops.push_back({OP_BEGIN_FOR_VAR, 0});
438 deUint32 loopheader = (deUint32)ops.size()-1;
441 loopNestingThisFunction++;
443 ops.push_back({OP_END_FOR_VAR, loopheader});
444 loopNestingThisFunction--;
// Infinite for loop; exits only via the subgroupElect+break emitted inside,
// tracked through isLoopInf/doneInfLoopBreak at this nesting level.
451 ops.push_back({OP_BEGIN_FOR_INF, 0});
452 deUint32 loopheader = (deUint32)ops.size()-1;
456 loopNestingThisFunction++;
457 isLoopInf[loopNesting] = true;
458 doneInfLoopBreak[loopNesting] = false;
// Mark that the mandatory elect+break has been emitted for this loop.
463 doneInfLoopBreak[loopNesting] = true;
467 ops.push_back({OP_END_FOR_INF, loopheader});
469 isLoopInf[loopNesting] = false;
470 doneInfLoopBreak[loopNesting] = false;
471 loopNestingThisFunction--;
// Infinite do/while loop; same bookkeeping as genForInf.
478 ops.push_back({OP_BEGIN_DO_WHILE_INF, 0});
479 deUint32 loopheader = (deUint32)ops.size()-1;
483 loopNestingThisFunction++;
484 isLoopInf[loopNesting] = true;
485 doneInfLoopBreak[loopNesting] = false;
490 doneInfLoopBreak[loopNesting] = true;
494 ops.push_back({OP_END_DO_WHILE_INF, loopheader});
496 isLoopInf[loopNesting] = false;
497 doneInfLoopBreak[loopNesting] = false;
498 loopNestingThisFunction--;
// Emit a break, but only when inside a loop in the current function
// (break must not cross an OP_CALL boundary).
505 if (loopNestingThisFunction > 0)
507 // Sometimes put the break in a divergent if
508 if ((deRandom_getUint32(&rnd) % 100) < 10)
510 ops.push_back({OP_IF_MASK, masks[0]});
511 ops.push_back({OP_BREAK, 0});
512 ops.push_back({OP_ELSE_MASK, 0});
513 ops.push_back({OP_BREAK, 0});
514 ops.push_back({OP_ENDIF, 0});
517 ops.push_back({OP_BREAK, 0});
523 // continues are allowed if we're in a loop and the loop is not infinite,
524 // or if it is infinite and we've already done a subgroupElect+break.
525 // However, adding more continues seems to reduce the failure rate, so
526 // disable it for now
527 if (loopNestingThisFunction > 0 && !(isLoopInf[loopNesting] /*&& !doneInfLoopBreak[loopNesting]*/))
529 // Sometimes put the continue in a divergent if
530 if ((deRandom_getUint32(&rnd) % 100) < 10)
532 ops.push_back({OP_IF_MASK, masks[0]});
533 ops.push_back({OP_CONTINUE, 0});
534 ops.push_back({OP_ELSE_MASK, 0});
535 ops.push_back({OP_CONTINUE, 0});
536 ops.push_back({OP_ENDIF, 0});
539 ops.push_back({OP_CONTINUE, 0});
543 // doBreak is used to generate "if (subgroupElect()) { ... break; }" inside infinite loops
544 void genElect(bool doBreak)
546 ops.push_back({OP_ELECT, 0});
550 // Put something interesting before the break
553 if ((deRandom_getUint32(&rnd) % 100) < 10)
556 // if we're in a function, sometimes use return instead
557 if (callNesting > 0 && (deRandom_getUint32(&rnd) % 100) < 30)
558 ops.push_back({OP_RETURN, 0});
566 ops.push_back({OP_ENDIF, 0});
// Emit a return with probability scaled by call/loop nesting depth.
572 deUint32 r = deRandom_getUint32(&rnd) % 100;
574 // Use return rarely in main, 20% of the time in a singly nested loop in a function
575 // and 50% of the time in a multiply nested loop in a function
577 (callNesting > 0 && loopNestingThisFunction > 0 && r < 20) ||
578 (callNesting > 0 && loopNestingThisFunction > 1 && r < 50)))
// Sometimes place the return inside a divergent if/else.
581 if ((deRandom_getUint32(&rnd) % 100) < 10)
583 ops.push_back({OP_IF_MASK, masks[0]});
584 ops.push_back({OP_RETURN, 0});
585 ops.push_back({OP_ELSE_MASK, 0});
586 ops.push_back({OP_RETURN, 0});
587 ops.push_back({OP_ENDIF, 0});
590 ops.push_back({OP_RETURN, 0});
594 // Generate a function call. Save and restore some loop information, which is used to
595 // determine when it's safe to use break/continue
598 ops.push_back({OP_CALL_BEGIN, 0});
// break/continue inside the callee must not target loops of the caller.
601 deInt32 saveLoopNestingThisFunction = loopNestingThisFunction;
602 loopNestingThisFunction = 0;
606 loopNestingThisFunction = saveLoopNestingThisFunction;
609 ops.push_back({OP_CALL_END, 0});
612 // Generate switch on a uniform value:
613 // switch (inputA.a[r]) {
614 // case r+1: ... break; // should not execute
615 // case r: ... break; // should branch uniformly
616 // case r+2: ... break; // should not execute
620 deUint32 r = deRandom_getUint32(&rnd) % 5;
621 ops.push_back({OP_SWITCH_UNIF_BEGIN, r});
// OP_CASE_MASK_BEGIN carries (invocation mask, case-label mask): only the
// case labeled r is reachable, so it gets mask ~0ULL.
624 ops.push_back({OP_CASE_MASK_BEGIN, 0, 1u<<(r+1)});
626 ops.push_back({OP_CASE_END, 0});
628 ops.push_back({OP_CASE_MASK_BEGIN, ~0ULL, 1u<<r});
630 ops.push_back({OP_CASE_END, 0});
632 ops.push_back({OP_CASE_MASK_BEGIN, 0, 1u<<(r+2)});
634 ops.push_back({OP_CASE_END, 0});
636 ops.push_back({OP_SWITCH_END, 0});
640 // switch (gl_SubgroupInvocationID & 3) with four unique targets
643 ops.push_back({OP_SWITCH_VAR_BEGIN, 0});
// Each case executes for every fourth invocation (lane % 4 == case label),
// hence the repeating-nibble invocation masks.
646 ops.push_back({OP_CASE_MASK_BEGIN, 0x1111111111111111ULL, 1<<0});
648 ops.push_back({OP_CASE_END, 0});
650 ops.push_back({OP_CASE_MASK_BEGIN, 0x2222222222222222ULL, 1<<1});
652 ops.push_back({OP_CASE_END, 0});
654 ops.push_back({OP_CASE_MASK_BEGIN, 0x4444444444444444ULL, 1<<2});
656 ops.push_back({OP_CASE_END, 0});
658 ops.push_back({OP_CASE_MASK_BEGIN, 0x8888888888888888ULL, 1<<3});
660 ops.push_back({OP_CASE_END, 0});
662 ops.push_back({OP_SWITCH_END, 0});
666 // switch (gl_SubgroupInvocationID & 3) with two shared targets.
667 // XXX TODO: The test considers these two targets to remain converged,
668 // though we haven't agreed to that behavior yet.
669 void genSwitchMulticase()
671 ops.push_back({OP_SWITCH_VAR_BEGIN, 0});
// Case labels 0 and 1 share one body (lanes 0,1 mod 4); labels 2 and 3 share the other.
674 ops.push_back({OP_CASE_MASK_BEGIN, 0x3333333333333333ULL, (1<<0)|(1<<1)});
676 ops.push_back({OP_CASE_END, 0});
678 ops.push_back({OP_CASE_MASK_BEGIN, 0xCCCCCCCCCCCCCCCCULL, (1<<2)|(1<<3)});
680 ops.push_back({OP_CASE_END, 0});
682 ops.push_back({OP_SWITCH_END, 0});
686 // switch (loopIdxN) {
687 // case 1: ... break;
688 // case 2: ... break;
689 // default: ... break;
691 void genSwitchLoopCount()
// Precondition (per pickOP): loopNesting > 0.
693 deUint32 r = deRandom_getUint32(&rnd) % loopNesting;
694 ops.push_back({OP_SWITCH_LOOP_COUNT_BEGIN, r});
// value is the mask of loop iterations the case executes on; caseValue is
// the case label (0xFFFFFFFF marks the default case, iterations != 1,2).
697 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, 1ULL<<1, 1});
699 ops.push_back({OP_CASE_END, 0});
701 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, 1ULL<<2, 2});
703 ops.push_back({OP_CASE_END, 0});
706 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, ~6ULL, 0xFFFFFFFF});
708 ops.push_back({OP_CASE_END, 0});
710 ops.push_back({OP_SWITCH_END, 0});
// Central random dispatcher: appends roughly "count" random constructs,
// each of which may recurse back into pickOP for its body.
714 void pickOP(deUint32 count)
716 // Pick "count" instructions. These can recursively insert more instructions,
717 // so "count" is just a seed
718 for (deUint32 i = 0; i < count; ++i)
// Only open new nesting constructs while below the per-case nesting budget.
721 if (nesting < maxNesting)
723 deUint32 r = deRandom_getUint32(&rnd) % 11;
737 genIf(IF_LOCAL_INVOCATION_INDEX);
747 // don't nest loops too deeply, to avoid extreme memory usage or timeouts
748 if (loopNesting <= 3)
750 deUint32 r2 = deRandom_getUint32(&rnd) % 3;
753 default: DE_ASSERT(0); // fallthrough
754 case 0: genForUnif(); break;
755 case 1: genForInf(); break;
756 case 2: genForVar(); break;
772 deUint32 r2 = deRandom_getUint32(&rnd) % 5;
// Function calls only from main (callNesting == 0) and with headroom left.
773 if (r2 == 0 && callNesting == 0 && nesting < maxNesting - 2)
781 // don't nest loops too deeply, to avoid extreme memory usage or timeouts
782 if (loopNesting <= 3)
784 deUint32 r2 = deRandom_getUint32(&rnd) % 2;
787 default: DE_ASSERT(0); // fallthrough
788 case 0: genDoWhileUnif(); break;
789 case 1: genDoWhileInf(); break;
796 deUint32 r2 = deRandom_getUint32(&rnd) % 4;
// Loop-count switches need at least one enclosing loop counter.
806 if (loopNesting > 0) {
807 genSwitchLoopCount();
812 if (caseDef.testType != TT_MAXIMAL)
814 // multicase doesn't have fully-defined behavior for MAXIMAL tests,
815 // but does for SUCF tests
816 genSwitchMulticase();
834 // optionally insert ballots, stores, and noise. Ballots and stores are used to determine
// Avoid emitting redundant back-to-back ballot (or store) pairs.
836 if ((deRandom_getUint32(&rnd) % 100) < 20)
838 if (ops.size() < 2 ||
839 !(ops[ops.size()-1].type == OP_BALLOT ||
840 (ops[ops.size()-1].type == OP_STORE && ops[ops.size()-2].type == OP_BALLOT)))
842 // do a store along with each ballot, so we can correlate where
843 // the ballot came from
844 if (caseDef.testType != TT_MAXIMAL)
845 ops.push_back({OP_STORE, (deUint32)ops.size() + storeBase});
846 ops.push_back({OP_BALLOT, 0});
850 if ((deRandom_getUint32(&rnd) % 100) < 10)
852 if (ops.size() < 2 ||
853 !(ops[ops.size()-1].type == OP_STORE ||
854 (ops[ops.size()-1].type == OP_BALLOT && ops[ops.size()-2].type == OP_STORE)))
856 // SUCF does a store with every ballot. Don't bloat the code by adding more.
857 if (caseDef.testType == TT_MAXIMAL)
858 ops.push_back({OP_STORE, (deUint32)ops.size() + storeBase});
// Rarely insert functionally-inert "noise" code (value selects the variant).
862 deUint32 r = deRandom_getUint32(&rnd) % 10000;
864 ops.push_back({OP_NOISE, 0});
866 ops.push_back({OP_NOISE, 1});
// Top-level driver: keep generating ops until the program is long enough,
// then pre-simulate it to classify ballots; retry for UCF cases until at
// least one uniform result exists to validate.
869 void generateRandomProgram()
873 while ((deInt32)ops.size() < minCount)
876 // Retry until the program has some UCF results in it
879 const deUint32 invocationStride = 128;
880 // Simulate for all subgroup sizes, to determine whether OP_BALLOTs are nonuniform
881 for (deInt32 subgroupSize = 4; subgroupSize <= 64; subgroupSize *= 2) {
// countOnly=true: no reference output is written during this pass.
882 simulate(true, subgroupSize, invocationStride, DE_NULL);
885 } while (caseDef.isUCF() && !hasUCF());
// Write the current indentation level into the GLSL output stream.
888 void printIndent(std::stringstream &css)
890 for (deInt32 i = 0; i < indent; ++i)
// Return the GLSL expression used to read a ballot as a uvec2
// (the low 64 bits of subgroupBallot, which covers all tested subgroup sizes).
std::string genPartitionBallot()
{
	std::stringstream ss;
	ss << "subgroupBallot(true).xy";
	return ss.str();
}
// Emit the GLSL that records one ballot result into outputB, bumping the
// per-invocation location counter in outputC as a side effect.
901 void printBallot(std::stringstream *css)
903 *css << "outputC.loc[gl_LocalInvocationIndex]++,";
904 // When inside loop(s), use partitionBallot rather than subgroupBallot to compute
905 // a ballot, to make sure the ballot is "diverged enough". Don't do this for
906 // subgroup_uniform_control_flow, since we only validate results that must be fully
908 if (loopNesting > 0 && caseDef.testType == TT_MAXIMAL)
910 *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex] = " << genPartitionBallot();
// Elect tests only record which lane was elected (single 32-bit component).
912 else if (caseDef.isElect())
914 *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex].x = elect()";
918 *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex] = subgroupBallot(true).xy";
// Translate the ops[] program into GLSL. Code between OP_CALL_BEGIN/END is
// redirected into "functions"; everything else goes into "main".
922 void genCode(std::stringstream &functions, std::stringstream &main)
924 std::stringstream *css = &main;
928 for (deInt32 i = 0; i < (deInt32)ops.size(); ++i)
// OP_IF_MASK: ~0ULL is the uniform special case (inputA.a[i] == i always holds).
934 if (ops[i].value == ~0ULL)
936 // This equality test will always succeed, since inputA.a[i] == i
937 int idx = deRandom_getUint32(&rnd) % 4;
938 *css << "if (inputA.a[" << idx << "] == " << idx << ") {\n";
941 *css << "if (testBit(uvec2(0x" << std::hex << (ops[i].value & 0xFFFFFFFF) << ", 0x" << (ops[i].value >> 32) << "), gl_SubgroupInvocationID)) {\n";
945 case OP_IF_LOOPCOUNT:
946 printIndent(*css); *css << "if (gl_SubgroupInvocationID == loopIdx" << loopNesting - 1 << ") {\n";
949 case OP_IF_LOCAL_INVOCATION_INDEX:
950 printIndent(*css); *css << "if (gl_LocalInvocationIndex >= inputA.a[0x" << std::hex << ops[i].value << "]) {\n";
954 case OP_ELSE_LOOPCOUNT:
955 case OP_ELSE_LOCAL_INVOCATION_INDEX:
957 printIndent(*css); *css << "} else {\n";
962 printIndent(*css); *css << "}\n";
965 printIndent(*css); printBallot(css); *css << ";\n";
// OP_STORE: record a literal marker value alongside the location counter.
968 printIndent(*css); *css << "outputC.loc[gl_LocalInvocationIndex]++;\n";
969 printIndent(*css); *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex].x = 0x" << std::hex << ops[i].value << ";\n";
971 case OP_BEGIN_FOR_UNIF:
972 printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;\n";
973 printIndent(*css); *css << "     loopIdx" << loopNesting << " < inputA.a[" << ops[i].value << "];\n";
974 printIndent(*css); *css << "     loopIdx" << loopNesting << "++) {\n";
978 case OP_END_FOR_UNIF:
981 printIndent(*css); *css << "}\n";
983 case OP_BEGIN_DO_WHILE_UNIF:
984 printIndent(*css); *css << "{\n";
986 printIndent(*css); *css << "int loopIdx" << loopNesting << " = 0;\n";
987 printIndent(*css); *css << "do {\n";
989 printIndent(*css); *css << "loopIdx" << loopNesting << "++;\n";
992 case OP_BEGIN_DO_WHILE_INF:
993 printIndent(*css); *css << "{\n";
995 printIndent(*css); *css << "int loopIdx" << loopNesting << " = 0;\n";
996 printIndent(*css); *css << "do {\n";
1000 case OP_END_DO_WHILE_UNIF:
// ops[i].value is the index of the matching loop-header op, whose value is
// the inputA index holding the trip count.
1003 printIndent(*css); *css << "} while (loopIdx" << loopNesting << " < inputA.a[" << ops[(deUint32)ops[i].value].value << "]);\n";
1005 printIndent(*css); *css << "}\n";
1007 case OP_END_DO_WHILE_INF:
1009 printIndent(*css); *css << "loopIdx" << loopNesting << "++;\n";
1011 printIndent(*css); *css << "} while (true);\n";
1013 printIndent(*css); *css << "}\n";
1015 case OP_BEGIN_FOR_VAR:
1016 printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;\n";
1017 printIndent(*css); *css << "     loopIdx" << loopNesting << " < gl_SubgroupInvocationID + 1;\n";
1018 printIndent(*css); *css << "     loopIdx" << loopNesting << "++) {\n";
1022 case OP_END_FOR_VAR:
1025 printIndent(*css); *css << "}\n";
1027 case OP_BEGIN_FOR_INF:
1028 printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;;loopIdx" << loopNesting << "++,";
1034 case OP_END_FOR_INF:
1037 printIndent(*css); *css << "}\n";
1040 printIndent(*css); *css << "break;\n";
1043 printIndent(*css); *css << "continue;\n";
1046 printIndent(*css); *css << "if (subgroupElect()) {\n";
1050 printIndent(*css); *css << "return;\n";
// OP_CALL_BEGIN: emit the call site, forwarding all live loop counters.
1053 printIndent(*css); *css << "func" << funcNum << "(";
1054 for (deInt32 n = 0; n < loopNesting; ++n)
1056 *css << "loopIdx" << n;
1057 if (n != loopNesting - 1)
// ... then emit the callee's signature into the functions stream.
1062 printIndent(*css); *css << "void func" << funcNum << "(";
1063 for (deInt32 n = 0; n < loopNesting; ++n)
1065 *css << "int loopIdx" << n;
1066 if (n != loopNesting - 1)
1075 printIndent(*css); *css << "}\n";
// OP_NOISE: functionally-inert filler (see enum comment for the variants).
1079 if (ops[i].value == 0)
1081 printIndent(*css); *css << "while (!subgroupElect()) {}\n";
1085 printIndent(*css); *css << "if (inputA.a[0] == 12345) {\n";
1087 printIndent(*css); *css << "while (true) {\n";
1089 printIndent(*css); printBallot(css); *css << ";\n";
1091 printIndent(*css); *css << "}\n";
1093 printIndent(*css); *css << "}\n";
1096 case OP_SWITCH_UNIF_BEGIN:
1097 printIndent(*css); *css << "switch (inputA.a[" << ops[i].value << "]) {\n";
1100 case OP_SWITCH_VAR_BEGIN:
1101 printIndent(*css); *css << "switch (gl_SubgroupInvocationID & 3) {\n";
1104 case OP_SWITCH_LOOP_COUNT_BEGIN:
1105 printIndent(*css); *css << "switch (loopIdx" << ops[i].value << ") {\n";
1110 printIndent(*css); *css << "}\n";
1112 case OP_CASE_MASK_BEGIN:
// Emit one "case b:" label for every bit set in the case-label mask.
1113 for (deInt32 b = 0; b < 32; ++b)
1115 if ((1u<<b) & ops[i].caseValue)
1117 printIndent(*css); *css << "case " << b << ":\n";
1120 printIndent(*css); *css << "{\n";
1123 case OP_CASE_LOOP_COUNT_BEGIN:
1124 if (ops[i].caseValue == 0xFFFFFFFF)
1126 printIndent(*css); *css << "default: {\n";
1130 printIndent(*css); *css << "case " << ops[i].caseValue << ": {\n";
1135 printIndent(*css); *css << "break;\n";
1137 printIndent(*css); *css << "}\n";
1146 // Simulate execution of the program. If countOnly is true, just return
1147 // the max number of outputs written. If it's false, store out the result
// CPU-side interpreter for ops[]: tracks the active mask per nesting level
// and writes expected per-invocation ballot/store values into ref[].
1149 deUint32 simulate(bool countOnly, deUint32 subgroupSize, deUint32 invocationStride, deUint64 *ref)
1151 // State of the subgroup at each level of nesting
1152 struct SubgroupState
1154 // Currently executing
1155 bitset128 activeMask;
1156 // Have executed a continue instruction in this loop
1157 bitset128 continueMask;
1158 // index of the current if test or loop header
1160 // number of loop iterations performed
1162 // is this nesting a loop?
1164 // is this nesting a function call?
1166 // is this nesting a switch?
1169 SubgroupState stateStack[10];
1170 deMemset(&stateStack, 0, sizeof(stateStack));
1172 const deUint64 fullSubgroupMask = subgroupSizeToMask(subgroupSize);
1174 // Per-invocation output location counters
1175 deUint32 outLoc[128] = {0};
1179 stateStack[nesting].activeMask = ~bitset128(); // initialized to ~0
1182 while (i < (deInt32)ops.size())
1184 switch (ops[i].type)
1188 // Flag that this ballot is workgroup-nonuniform
1189 if (caseDef.isWUCF() && stateStack[nesting].activeMask.any() && !stateStack[nesting].activeMask.all())
1190 ops[i].caseValue = 1;
1192 if (caseDef.isSUCF())
1194 for (deUint32 id = 0; id < 128; id += subgroupSize)
1196 deUint64 subgroupMask = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
1197 // Flag that this ballot is subgroup-nonuniform
1198 if (subgroupMask != 0 && subgroupMask != fullSubgroupMask)
1199 ops[i].caseValue = 1;
// Write the expected ballot value for every active invocation.
1203 for (deUint32 id = 0; id < 128; ++id)
1205 if (stateStack[nesting].activeMask.test(id))
1213 if (ops[i].caseValue)
1215 // Emit a magic value to indicate that we shouldn't validate this ballot
1216 ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(0x12345678, subgroupSize, id);
1219 ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
// OP_STORE: active invocations record the literal store value.
1225 for (deUint32 id = 0; id < 128; ++id)
1227 if (stateStack[nesting].activeMask.test(id))
1232 ref[(outLoc[id]++)*invocationStride + id] = ops[i].value;
// OP_IF_MASK: push a new nesting level with the mask applied.
1238 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64(ops[i].value, subgroupSize);
1239 stateStack[nesting].header = i;
1240 stateStack[nesting].isLoop = 0;
1241 stateStack[nesting].isSwitch = 0;
// OP_ELSE_MASK: complement of the matching if's mask.
1244 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & ~bitsetFromU64(ops[stateStack[nesting].header].value, subgroupSize);
1246 case OP_IF_LOOPCOUNT:
// Find the innermost enclosing loop to read its trip count.
1248 deUint32 n = nesting;
1249 while (!stateStack[n].isLoop)
1253 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64((1ULL << stateStack[n].tripCount), subgroupSize);
1254 stateStack[nesting].header = i;
1255 stateStack[nesting].isLoop = 0;
1256 stateStack[nesting].isSwitch = 0;
1259 case OP_ELSE_LOOPCOUNT:
1261 deUint32 n = nesting;
1262 while (!stateStack[n].isLoop)
1265 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & ~bitsetFromU64((1ULL << stateStack[n].tripCount), subgroupSize);
1268 case OP_IF_LOCAL_INVOCATION_INDEX:
// Build the mask of invocations with index >= ops[i].value.
1272 for (deInt32 j = (deInt32)ops[i].value; j < 128; ++j)
1276 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & mask;
1277 stateStack[nesting].header = i;
1278 stateStack[nesting].isLoop = 0;
1279 stateStack[nesting].isSwitch = 0;
1282 case OP_ELSE_LOCAL_INVOCATION_INDEX:
1286 for (deInt32 j = 0; j < (deInt32)ops[i].value; ++j)
1289 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & mask;
1295 case OP_BEGIN_FOR_UNIF:
1296 // XXX TODO: We don't handle a for loop with zero iterations
1299 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1300 stateStack[nesting].header = i;
1301 stateStack[nesting].tripCount = 0;
1302 stateStack[nesting].isLoop = 1;
1303 stateStack[nesting].isSwitch = 0;
1304 stateStack[nesting].continueMask = 0;
1306 case OP_END_FOR_UNIF:
// Re-activate continued invocations, then loop back if iterations remain.
1307 stateStack[nesting].tripCount++;
1308 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1309 stateStack[nesting].continueMask = 0;
1310 if (stateStack[nesting].tripCount < ops[stateStack[nesting].header].value &&
1311 stateStack[nesting].activeMask.any())
1313 i = stateStack[nesting].header+1;
1322 case OP_BEGIN_DO_WHILE_UNIF:
1323 // XXX TODO: We don't handle a for loop with zero iterations
1326 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1327 stateStack[nesting].header = i;
// do/while has already executed once by the time the condition is tested.
1328 stateStack[nesting].tripCount = 1;
1329 stateStack[nesting].isLoop = 1;
1330 stateStack[nesting].isSwitch = 0;
1331 stateStack[nesting].continueMask = 0;
1333 case OP_END_DO_WHILE_UNIF:
1334 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1335 stateStack[nesting].continueMask = 0;
1336 if (stateStack[nesting].tripCount < ops[stateStack[nesting].header].value &&
1337 stateStack[nesting].activeMask.any())
1339 i = stateStack[nesting].header+1;
1340 stateStack[nesting].tripCount++;
1349 case OP_BEGIN_FOR_VAR:
1350 // XXX TODO: We don't handle a for loop with zero iterations
1353 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1354 stateStack[nesting].header = i;
1355 stateStack[nesting].tripCount = 0;
1356 stateStack[nesting].isLoop = 1;
1357 stateStack[nesting].isSwitch = 0;
1358 stateStack[nesting].continueMask = 0;
1360 case OP_END_FOR_VAR:
1361 stateStack[nesting].tripCount++;
1362 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1363 stateStack[nesting].continueMask = 0;
// Lanes with gl_SubgroupInvocationID + 1 <= tripCount have finished;
// the tripCount == subgroupSize ternary avoids a 64-bit shift by 64 (UB).
1364 stateStack[nesting].activeMask &= bitsetFromU64(stateStack[nesting].tripCount == subgroupSize ? 0 : ~((1ULL << (stateStack[nesting].tripCount)) - 1), subgroupSize);
1365 if (stateStack[nesting].activeMask.any())
1367 i = stateStack[nesting].header+1;
1376 case OP_BEGIN_FOR_INF:
1377 case OP_BEGIN_DO_WHILE_INF:
1380 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1381 stateStack[nesting].header = i;
1382 stateStack[nesting].tripCount = 0;
1383 stateStack[nesting].isLoop = 1;
1384 stateStack[nesting].isSwitch = 0;
1385 stateStack[nesting].continueMask = 0;
1387 case OP_END_FOR_INF:
1388 stateStack[nesting].tripCount++;
1389 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1390 stateStack[nesting].continueMask = 0;
1391 if (stateStack[nesting].activeMask.any())
1393 // output expected OP_BALLOT values
// (matches the ballot emitted in the loop's continue construct in genCode)
1394 for (deUint32 id = 0; id < 128; ++id)
1396 if (stateStack[nesting].activeMask.test(id))
1401 ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
1405 i = stateStack[nesting].header+1;
1414 case OP_END_DO_WHILE_INF:
1415 stateStack[nesting].tripCount++;
1416 stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
1417 stateStack[nesting].continueMask = 0;
1418 if (stateStack[nesting].activeMask.any())
1420 i = stateStack[nesting].header+1;
// OP_BREAK: deactivate breaking lanes at every level up to the loop/switch.
1431 deUint32 n = nesting;
1432 bitset128 mask = stateStack[nesting].activeMask;
1435 stateStack[n].activeMask &= ~mask;
1436 if (stateStack[n].isLoop || stateStack[n].isSwitch)
// OP_CONTINUE: like break, but the lanes rejoin at the loop's latch.
1445 deUint32 n = nesting;
1446 bitset128 mask = stateStack[nesting].activeMask;
1449 stateStack[n].activeMask &= ~mask;
1450 if (stateStack[n].isLoop)
1452 stateStack[n].continueMask |= mask;
// OP_ELECT: only the lowest active lane of each subgroup stays active.
1462 stateStack[nesting].activeMask = bitsetElect(stateStack[nesting-1].activeMask, subgroupSize);
1463 stateStack[nesting].header = i;
1464 stateStack[nesting].isLoop = 0;
1465 stateStack[nesting].isSwitch = 0;
// OP_RETURN: deactivate returning lanes up to the enclosing call frame.
1470 bitset128 mask = stateStack[nesting].activeMask;
1471 for (deInt32 n = nesting; n >= 0; --n)
1473 stateStack[n].activeMask &= ~mask;
1474 if (stateStack[n].isCall)
1482 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1483 stateStack[nesting].isLoop = 0;
1484 stateStack[nesting].isSwitch = 0;
1485 stateStack[nesting].isCall = 1;
1488 stateStack[nesting].isCall = 0;
1494 case OP_SWITCH_UNIF_BEGIN:
1495 case OP_SWITCH_VAR_BEGIN:
1496 case OP_SWITCH_LOOP_COUNT_BEGIN:
1498 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1499 stateStack[nesting].header = i;
1500 stateStack[nesting].isLoop = 0;
1501 stateStack[nesting].isSwitch = 1;
1506 case OP_CASE_MASK_BEGIN:
1507 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64(ops[i].value, subgroupSize);
1509 case OP_CASE_LOOP_COUNT_BEGIN:
// Walk outward to find the loop whose counter the switch selected on.
1511 deUint32 n = nesting;
1512 deUint32 l = loopNesting;
1516 if (stateStack[n].isLoop)
1519 if (l == ops[stateStack[nesting].header].value)
// Active only on iterations selected by this case's iteration mask.
1525 if ((1ULL << stateStack[n].tripCount) & ops[i].value)
1526 stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
1528 stateStack[nesting].activeMask = 0;
// Result: maximum number of output slots any invocation wrote.
1540 deUint32 maxLoc = 0;
1541 for (deUint32 id = 0; id < ARRAYSIZE(outLoc); ++id)
1542 maxLoc = de::max(maxLoc, outLoc[id]);
1549 for (deInt32 i = 0; i < (deInt32)ops.size(); ++i)
1551 if (ops[i].type == OP_BALLOT && ops[i].caseValue == 0)
// Generate the GLSL compute shader for this randomized reconvergence test.
// The shader is built from the same seeded RandomProgram that iterate()
// later re-generates for the CPU-side simulation, so both sides execute an
// identical sequence of ops.
1558 void ReconvergenceTestCase::initPrograms (SourceCollections& programCollection) const
1560 	RandomProgram program(m_data);
1561 	program.generateRandomProgram();
1563 	std::stringstream css;
1564 	css << "#version 450 core\n";
1565 	css << "#extension GL_KHR_shader_subgroup_ballot : enable\n";
1566 	css << "#extension GL_KHR_shader_subgroup_vote : enable\n";
1567 	css << "#extension GL_NV_shader_subgroup_partitioned : enable\n";
1568 	css << "#extension GL_EXT_subgroup_uniform_control_flow : enable\n";
// local_size_x is spec constant 0; iterate() supplies its value via
// VkSpecializationInfo at pipeline-creation time.
1569 	css << "layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n";
// Bindings 0..2 mirror the three host buffers created in iterate():
// input data, ballot output, and per-invocation location counts.
1570 	css << "layout(set=0, binding=0) coherent buffer InputA { uint a[]; } inputA;\n";
1571 	css << "layout(set=0, binding=1) coherent buffer OutputB { uvec2 b[]; } outputB;\n";
1572 	css << "layout(set=0, binding=2) coherent buffer OutputC { uint loc[]; } outputC;\n";
1573 	css << "layout(push_constant) uniform PC {\n"
1574 		"   // set to the real stride when writing out ballots, or zero when just counting\n"
1575 		"   int invocationStride;\n"
1577 	css << "int outLoc = 0;\n";
// Helper to test one bit of a uvec2 ballot mask (low 32 bits in .x, high in .y).
1579 	css << "bool testBit(uvec2 mask, uint bit) { return (bit < 32) ? ((mask.x >> bit) & 1) != 0 : ((mask.y >> (bit-32)) & 1) != 0; }\n";
// elect() encodes subgroupElect() as 2 for the elected lane, 1 otherwise;
// the CPU-side checker for *_ELECT test types expects exactly these values.
1581 	css << "uint elect() { return int(subgroupElect()) + 1; }\n";
1583 	std::stringstream functions, main;
1584 	program.genCode(functions, main);
1586 	css << functions.str() << "\n\n";
// For subgroup_uniform_control_flow variants, annotate the entry point with
// the GL_EXT_subgroup_uniform_control_flow attribute.
1590 		<< (m_data.isSUCF() ? "[[subgroup_uniform_control_flow]]\n" : "") <<
1593 	css << main.str() << "\n\n";
// Build as SPIR-V 1.3 (presumably the minimum needed for the subgroup
// operations used above -- TODO confirm against the extension requirements).
1597 	const vk::ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
1599 	programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
// Factory: create the per-run instance, forwarding the randomized case
// definition (test type, max nesting, seed) captured in m_data.
1602 TestInstance* ReconvergenceTestCase::createInstance (Context& context) const
1604 	return new ReconvergenceTestInstance(context, m_data);
// Run one reconvergence test:
//   1. Re-generate the seeded random program and simulate it on the CPU to
//      get an upper bound on output locations.
//   2. Dispatch the shader once with invocationStride==0 (counting pass) to
//      learn the real per-invocation output count, growing the output buffer
//      if needed.
//   3. Dispatch again with the real stride so the shader writes its ballots.
//   4. Simulate on the CPU and compare against the GPU output (exact match
//      for TT_MAXIMAL, convergence-pattern match for SUCF/WUCF variants).
1607 tcu::TestStatus ReconvergenceTestInstance::iterate (void)
1609 	const DeviceInterface&	vk						= m_context.getDeviceInterface();
1610 	const VkDevice			device					= m_context.getDevice();
1611 	Allocator&				allocator				= m_context.getDefaultAllocator();
1612 	tcu::TestLog&			log						= m_context.getTestContext().getLog();
// Seed the host-side RNG with the case seed (rnd is declared on an elided line).
1615 	deRandom_init(&rnd, m_data.seed);
// Query the implementation's subgroup size via VkPhysicalDeviceSubgroupProperties
// chained into getPhysicalDeviceProperties2.
1617 	vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
1618 	deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
1619 	subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
1621 	vk::VkPhysicalDeviceProperties2 properties2;
1622 	deMemset(&properties2, 0, sizeof(properties2));
1623 	properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
1624 	properties2.pNext = &subgroupProperties;
1626 	m_context.getInstanceInterface().getPhysicalDeviceProperties2(m_context.getPhysicalDevice(), &properties2);
1628 	const deUint32 subgroupSize = subgroupProperties.subgroupSize;
// Fixed workgroup width: 128 invocations per dispatch.
1629 	const deUint32 invocationStride = 128;
// Ballot masks are stored as 64-bit values on the host (deUint64) and uvec2
// in the shader, so larger subgroups cannot be represented.
1631 	if (subgroupSize > 64)
1632 		TCU_THROW(TestError, "Subgroup size greater than 64 not handled.");
// Re-generate the same random program the shader was built from, and run the
// CPU simulation in "counting" mode (countOnly=true, no output buffer) to get
// a per-invocation estimate of the number of output locations.
1634 	RandomProgram program(m_data);
1635 	program.generateRandomProgram();
1637 	deUint32 maxLoc = program.simulate(true, subgroupSize, invocationStride, DE_NULL);
1639 	// maxLoc is per-invocation. Add one (to make sure no additional writes are done) and multiply by
1640 	// the number of invocations
1642 	maxLoc *= invocationStride;
1644 	// buffer[0] is an input filled with a[i] == i
1645 	// buffer[1] is the output
1646 	// buffer[2] is the location counts
1647 	de::MovePtr<BufferWithMemory> buffers[3];
1648 	vk::VkDescriptorBufferInfo bufferDescriptors[3];
1650 	VkDeviceSize sizes[3] =
1652 		128 * sizeof(deUint32),
1653 		maxLoc * sizeof(deUint64),
1654 		invocationStride * sizeof(deUint32),
1657 	for (deUint32 i = 0; i < 3; ++i)
1661 			buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1662 				vk, device, allocator, makeBufferCreateInfo(sizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT),
1663 				MemoryRequirement::HostVisible | MemoryRequirement::Cached));
1665 		catch(tcu::ResourceError&)
1667 			// Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
1668 			return tcu::TestStatus(QP_TEST_RESULT_QUALITY_WARNING, "Failed device memory allocation " + de::toString(sizes[i]) + " bytes");
1670 		bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, sizes[i]);
// Map all three buffers and initialize: input a[i]=i (fill loop partly on
// elided lines), output and counts zeroed.
1674 	for (deUint32 i = 0; i < 3; ++i)
1676 		ptrs[i] = (deUint32 *)buffers[i]->getAllocation().getHostPtr();
1678 	for (deUint32 i = 0; i < sizes[0] / sizeof(deUint32); ++i)
1682 	deMemset(ptrs[1], 0, (size_t)sizes[1]);
1683 	deMemset(ptrs[2], 0, (size_t)sizes[2]);
// Descriptor set layout: three storage buffers matching shader bindings 0..2.
1685 	vk::DescriptorSetLayoutBuilder layoutBuilder;
1687 	layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1688 	layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1689 	layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1691 	vk::Unique<vk::VkDescriptorSetLayout>	descriptorSetLayout(layoutBuilder.build(vk, device));
1693 	vk::Unique<vk::VkDescriptorPool>		descriptorPool(vk::DescriptorPoolBuilder()
1694 		.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 3u)
1695 		.build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
1696 	vk::Unique<vk::VkDescriptorSet>			descriptorSet	(makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));
// Specialization constant 0 supplies local_size_x to the shader
// (see initPrograms' local_size_x_id = 0).
1698 	const deUint32 specData[1] =
1702 	const vk::VkSpecializationMapEntry entries[1] =
1704 		{0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
1706 	const vk::VkSpecializationInfo specInfo =
1709 		entries,			// pMapEntries
1710 		sizeof(specData),	// dataSize
// Single push constant: invocationStride (0 = counting pass, real stride = write pass).
1714 	const VkPushConstantRange pushConstantRange =
1716 		allShaderStages,											// VkShaderStageFlags	stageFlags;
1717 		0u,															// deUint32				offset;
1718 		sizeof(deInt32)												// deUint32				size;
1721 	const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo =
1723 		VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,				// sType
1725 		(VkPipelineLayoutCreateFlags)0,
1726 		1,															// setLayoutCount
1727 		&descriptorSetLayout.get(),									// pSetLayouts
1728 		1u,															// pushConstantRangeCount
1729 		&pushConstantRange,											// pPushConstantRanges
1732 	Move<VkPipelineLayout> pipelineLayout = createPipelineLayout(vk, device, &pipelineLayoutCreateInfo, NULL);
1734 	VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
1736 	flushAlloc(vk, device, buffers[0]->getAllocation());
1737 	flushAlloc(vk, device, buffers[1]->getAllocation());
1738 	flushAlloc(vk, device, buffers[2]->getAllocation());
// If VK_EXT_subgroup_size_control's computeFullSubgroups feature is available
// (and the subgroup size is representable), require full subgroups and pin
// the subgroup size, so the GPU layout matches the CPU simulation.
1740 	const VkBool32 computeFullSubgroups = subgroupProperties.subgroupSize <= 64 &&
1741 											m_context.getSubgroupSizeControlFeaturesEXT().computeFullSubgroups;
1743 	const VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT subgroupSizeCreateInfo =
1745 		VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT,	// VkStructureType	sType;
1746 		DE_NULL,																		// void*			pNext;
1747 		subgroupProperties.subgroupSize													// uint32_t			requiredSubgroupSize;
1750 	const void *shaderPNext = computeFullSubgroups ? &subgroupSizeCreateInfo : DE_NULL;
1751 	VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags =
1752 		(VkPipelineShaderStageCreateFlags)(computeFullSubgroups ? VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT : 0);
1754 	const Unique<VkShaderModule>			shader						(createShaderModule(vk, device, m_context.getBinaryCollection().get("test"), 0));
1755 	const VkPipelineShaderStageCreateInfo	shaderCreateInfo =
1757 		VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1759 		pipelineShaderStageCreateFlags,
1760 		VK_SHADER_STAGE_COMPUTE_BIT,								// stage
1763 		&specInfo,													// pSpecializationInfo
1766 	const VkComputePipelineCreateInfo		pipelineCreateInfo =
1768 		VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1771 		shaderCreateInfo,											// cs
1772 		*pipelineLayout,											// layout
1773 		(vk::VkPipeline)0,											// basePipelineHandle
1774 		0u,															// basePipelineIndex
1776 	Move<VkPipeline> pipeline = createComputePipeline(vk, device, DE_NULL, &pipelineCreateInfo, NULL);
1778 	const VkQueue					queue					= m_context.getUniversalQueue();
1779 	Move<VkCommandPool>				cmdPool					= createCommandPool(vk, device, vk::VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT, m_context.getUniversalQueueFamilyIndex());
1780 	Move<VkCommandBuffer>			cmdBuffer				= allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1783 	vk::DescriptorSetUpdateBuilder setUpdateBuilder;
1784 	setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
1785 		VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
1786 	setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1787 		VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1788 	setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
1789 		VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
1790 	setUpdateBuilder.update(vk, device);
1792 	// compute "maxLoc", the maximum number of locations written
1793 	beginCommandBuffer(vk, *cmdBuffer, 0u);
1795 	vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
1796 	vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);
// Counting pass: stride 0 tells the shader to only bump the per-invocation
// location counters (binding 2) rather than write ballots.
1798 	deInt32 pcinvocationStride = 0;
1799 	vk.cmdPushConstants(*cmdBuffer, *pipelineLayout, allShaderStages, 0, sizeof(pcinvocationStride), &pcinvocationStride);
1801 	vk.cmdDispatch(*cmdBuffer, 1, 1, 1);
1803 	endCommandBuffer(vk, *cmdBuffer);
1805 	submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1807 	invalidateAlloc(vk, device, buffers[1]->getAllocation());
1808 	invalidateAlloc(vk, device, buffers[2]->getAllocation());
1810 	// Clear any writes to buffer[1] during the counting pass
1811 	deMemset(ptrs[1], 0, invocationStride * sizeof(deUint64));
1813 	// Take the max over all invocations. Add one (to make sure no additional writes are done) and multiply by
1814 	// the number of invocations
1815 	deUint32 newMaxLoc = 0;
1816 	for (deUint32 id = 0; id < invocationStride; ++id)
1817 		newMaxLoc = de::max(newMaxLoc, ptrs[2][id]);
1819 	newMaxLoc *= invocationStride;
1821 	// If we need more space, reallocate buffers[1]
1822 	if (newMaxLoc > maxLoc)
// NOTE(review): this sizing relies on maxLoc having been updated to
// newMaxLoc on an elided line just above -- verify in the full file.
1825 		sizes[1] = maxLoc * sizeof(deUint64);
1829 			buffers[1] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1830 				vk, device, allocator, makeBufferCreateInfo(sizes[1], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT),
1831 				MemoryRequirement::HostVisible | MemoryRequirement::Cached));
1833 		catch(tcu::ResourceError&)
1835 			// Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
1836 			return tcu::TestStatus(QP_TEST_RESULT_QUALITY_WARNING, "Failed device memory allocation " + de::toString(sizes[1]) + " bytes");
1838 		bufferDescriptors[1] = makeDescriptorBufferInfo(**buffers[1], 0, sizes[1]);
1839 		ptrs[1] = (deUint32 *)buffers[1]->getAllocation().getHostPtr();
1840 		deMemset(ptrs[1], 0, (size_t)sizes[1]);
// Rebind only binding 1, which now points at the (possibly new) output buffer.
1842 		vk::DescriptorSetUpdateBuilder setUpdateBuilder2;
1843 		setUpdateBuilder2.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1844 			VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1845 		setUpdateBuilder2.update(vk, device);
1848 	flushAlloc(vk, device, buffers[1]->getAllocation());
1850 	// run the actual shader
1851 	beginCommandBuffer(vk, *cmdBuffer, 0u);
1853 	vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
1854 	vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);
// Write pass: real stride makes the shader store ballots at
// ref-layout-compatible offsets (loc*invocationStride + id).
1856 	pcinvocationStride = invocationStride;
1857 	vk.cmdPushConstants(*cmdBuffer, *pipelineLayout, allShaderStages, 0, sizeof(pcinvocationStride), &pcinvocationStride);
1859 	vk.cmdDispatch(*cmdBuffer, 1, 1, 1);
1861 	endCommandBuffer(vk, *cmdBuffer);
1863 	submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1865 	invalidateAlloc(vk, device, buffers[1]->getAllocation());
1867 	qpTestResult res = QP_TEST_RESULT_PASS;
1869 	// Simulate execution on the CPU, and compare against the GPU result
1870 	deUint64 *ref = new deUint64 [maxLoc];
1871 	deMemset(ref, 0, maxLoc*sizeof(deUint64));
1872 	program.simulate(false, subgroupSize, invocationStride, ref);
1874 	const deUint64 *result = (const deUint64 *)ptrs[1];
1876 	if (m_data.testType == TT_MAXIMAL)
1878 		// With maximal reconvergence, we should expect the output to exactly match
1880 		for (deUint32 i = 0; i < maxLoc; ++i)
1882 			if (result[i] != ref[i])
1884 				log << tcu::TestLog::Message << "first mismatch at " << i << tcu::TestLog::EndMessage;
1885 				res = QP_TEST_RESULT_FAIL;
1890 		if (res != QP_TEST_RESULT_PASS)
1892 			for (deUint32 i = 0; i < maxLoc; ++i)
1894 				// This log can be large and slow, ifdef it out by default
1896 				log << tcu::TestLog::Message << "result " << i << "(" << (i/invocationStride) << ", " << (i%invocationStride) << "): " << tcu::toHex(result[i]) << " ref " << tcu::toHex(ref[i]) << (result[i] != ref[i] ? " different" : "") << tcu::TestLog::EndMessage;
1903 		deUint64 fullMask = subgroupSizeToMask(subgroupSize);
1904 		// For subgroup_uniform_control_flow, we expect any fully converged outputs in the reference
1905 		// to have a corresponding fully converged output in the result. So walk through each lane's
1906 		// results, and for each reference value of fullMask, find a corresponding result value of
1907 		// fullMask where the previous value (OP_STORE) matches. That means these came from the same
1909 		vector<deUint32> firstFail(invocationStride, 0);
1910 		for (deUint32 lane = 0; lane < invocationStride; ++lane)
// Start at the second row: index `lane` itself is row 0, and the matching
// logic below looks back one row (resLoc-invocationStride) for the store value.
1912 			deUint32 resLoc = lane + invocationStride, refLoc = lane + invocationStride;
1913 			while (refLoc < maxLoc)
// Advance to the next fully-converged ballot in the reference for this lane.
1915 				while (refLoc < maxLoc && ref[refLoc] != fullMask)
1916 					refLoc += invocationStride;
1917 				if (refLoc >= maxLoc)
1920 				// For TT_SUCF_ELECT, when the reference result has a full mask, we expect lane 0 to be elected
1921 				// (a value of 2) and all other lanes to be not elected (a value of 1). For TT_SUCF_BALLOT, we
1922 				// expect a full mask. Search until we find the expected result with a matching store value in
1923 				// the previous result.
1924 				deUint64 expectedResult = m_data.isElect() ? ((lane % subgroupSize) == 0 ? 2 : 1)
1927 				while (resLoc < maxLoc && !(result[resLoc] == expectedResult && result[resLoc-invocationStride] == ref[refLoc-invocationStride]))
1928 					resLoc += invocationStride;
1930 				// If we didn't find this output in the result, flag it as an error.
1931 				if (resLoc >= maxLoc)
1933 					firstFail[lane] = refLoc;
1934 					log << tcu::TestLog::Message << "lane " << lane << " first mismatch at " << firstFail[lane] << tcu::TestLog::EndMessage;
1935 					res = QP_TEST_RESULT_FAIL;
1938 				refLoc += invocationStride;
1939 				resLoc += invocationStride;
1943 		if (res != QP_TEST_RESULT_PASS)
1945 			for (deUint32 i = 0; i < maxLoc; ++i)
1947 				// This log can be large and slow, ifdef it out by default
1949 				log << tcu::TestLog::Message << "result " << i << "(" << (i/invocationStride) << ", " << (i%invocationStride) << "): " << tcu::toHex(result[i]) << " ref " << tcu::toHex(ref[i]) << (i == firstFail[i%invocationStride] ? " first fail" : "") << tcu::TestLog::EndMessage;
1957 	return tcu::TestStatus(res, qpGetTestResultName(res));
// Build the test hierarchy:
//   reconvergence / <test-type> / compute / nesting<2..6> / <seed 0..7> / <ndx>
// createExperimental splits the same case set into a stable group and an
// experimental group (a case lands in exactly one of the two trees).
1962 tcu::TestCaseGroup*	createTests (tcu::TestContext& testCtx, bool createExperimental)
1964 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
1965 			testCtx, "reconvergence", "reconvergence tests"));
1971 		const char*	description;
1974 	TestGroupCase ttCases[] =
1976 		{ TT_SUCF_ELECT,	"subgroup_uniform_control_flow_elect",	"subgroup_uniform_control_flow_elect"	},
1977 		{ TT_SUCF_BALLOT,	"subgroup_uniform_control_flow_ballot",	"subgroup_uniform_control_flow_ballot"	},
1978 		{ TT_WUCF_ELECT,	"workgroup_uniform_control_flow_elect",	"workgroup_uniform_control_flow_elect"	},
1979 		{ TT_WUCF_BALLOT,	"workgroup_uniform_control_flow_ballot","workgroup_uniform_control_flow_ballot"	},
1980 		{ TT_MAXIMAL,		"maximal",								"maximal"								},
1983 	for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
1985 		de::MovePtr<tcu::TestCaseGroup> ttGroup	(new tcu::TestCaseGroup(testCtx, ttCases[ttNdx].name, ttCases[ttNdx].description));
1986 		de::MovePtr<tcu::TestCaseGroup>	computeGroup	(new tcu::TestCaseGroup(testCtx, "compute", ""));
// One subtree per maximum control-flow nesting depth (2..6).
1988 		for (deUint32 nNdx = 2; nNdx <= 6; nNdx++)
1990 			de::MovePtr<tcu::TestCaseGroup> nestGroup	(new tcu::TestCaseGroup(testCtx, ("nesting" + de::toString(nNdx)).c_str(), ""));
// Eight RNG seed buckets per nesting level.
1994 			for (int sNdx = 0; sNdx < 8; sNdx++)
1996 				de::MovePtr<tcu::TestCaseGroup> seedGroup	(new tcu::TestCaseGroup(testCtx, de::toString(sNdx).c_str(), ""));
// numTests is set on elided lines; non-maximal test types take a further
// adjustment below.
1998 				deUint32 numTests = 0;
2017 				if (ttCases[ttNdx].value != TT_MAXIMAL)
2023 				for (deUint32 ndx = 0; ndx < numTests; ndx++)
// CaseDef for one generated program (remaining fields on elided lines).
2027 						(TestType)ttCases[ttNdx].value,		// TestType testType;
2028 						nNdx,								// deUint32 maxNesting;
2029 						seed,								// deUint32 seed;
// Non-UCF cases and the last 4/5 of each bucket are "experimental"; a case
// is added only to the tree matching the createExperimental flag.
2033 					bool isExperimentalTest = !c.isUCF() || (ndx >= numTests / 5);
2035 					if (createExperimental == isExperimentalTest)
2036 						seedGroup->addChild(new ReconvergenceTestCase(testCtx, de::toString(ndx).c_str(), "", c));
// Only attach non-empty groups so the other tree has no empty nodes.
2038 				if (!seedGroup->empty())
2039 					nestGroup->addChild(seedGroup.release());
2041 			if (!nestGroup->empty())
2042 				computeGroup->addChild(nestGroup.release());
2044 		if (!computeGroup->empty())
2046 			ttGroup->addChild(computeGroup.release());
2047 			group->addChild(ttGroup.release());
2050 	return group.release();