106cad2b75b30bd00fdeb5c35c9ec1c3a63f7a92
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / subgroups / vktSubgroupsBallotOtherTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsBallotOtherTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
// Selects which GL_KHR_shader_subgroup_ballot built-in a test case exercises.
// Enumerators are numbered consecutively from zero; OPTYPE_LAST is one past
// the final real op type and is a sentinel, not an op type of its own.
enum OpType
{
	OPTYPE_INVERSE_BALLOT				= 0,	// subgroupInverseBallot
	OPTYPE_BALLOT_BIT_EXTRACT			= 1,	// subgroupBallotBitExtract
	OPTYPE_BALLOT_BIT_COUNT				= 2,	// subgroupBallotBitCount
	OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT	= 3,	// subgroupBallotInclusiveBitCount
	OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT	= 4,	// subgroupBallotExclusiveBitCount
	OPTYPE_BALLOT_FIND_LSB				= 5,	// subgroupBallotFindLSB
	OPTYPE_BALLOT_FIND_MSB				= 6,	// subgroupBallotFindMSB
	OPTYPE_LAST							= 7		// number of op types (sentinel)
};
49
50 static bool checkVertexPipelineStages(std::vector<const void*> datas,
51                                                                           deUint32 width, deUint32)
52 {
53         const deUint32* data =
54                 reinterpret_cast<const deUint32*>(datas[0]);
55         for (deUint32 x = 0; x < width; ++x)
56         {
57                 deUint32 val = data[x];
58
59                 if (0xf != val)
60                 {
61                         return false;
62                 }
63         }
64
65         return true;
66 }
67
68 static bool checkFragment(std::vector<const void*> datas,
69                                                   deUint32 width, deUint32 height, deUint32)
70 {
71         const deUint32* data =
72                 reinterpret_cast<const deUint32*>(datas[0]);
73         for (deUint32 x = 0; x < width; ++x)
74         {
75                 for (deUint32 y = 0; y < height; ++y)
76                 {
77                         deUint32 val = data[x * height + y];
78
79                         if (0xf != val)
80                         {
81                                 return false;
82                         }
83                 }
84         }
85
86         return true;
87 }
88
89 static bool checkCompute(std::vector<const void*> datas,
90                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
91                                                  deUint32)
92 {
93         const deUint32* data =
94                 reinterpret_cast<const deUint32*>(datas[0]);
95
96         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
97         {
98                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
99                 {
100                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
101                         {
102                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
103                                 {
104                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
105                                         {
106                                                 for (deUint32 lZ = 0; lZ < localSize[2];
107                                                                 ++lZ)
108                                                 {
109                                                         const deUint32 globalInvocationX =
110                                                                 nX * localSize[0] + lX;
111                                                         const deUint32 globalInvocationY =
112                                                                 nY * localSize[1] + lY;
113                                                         const deUint32 globalInvocationZ =
114                                                                 nZ * localSize[2] + lZ;
115
116                                                         const deUint32 globalSizeX =
117                                                                 numWorkgroups[0] * localSize[0];
118                                                         const deUint32 globalSizeY =
119                                                                 numWorkgroups[1] * localSize[1];
120
121                                                         const deUint32 offset =
122                                                                 globalSizeX *
123                                                                 ((globalSizeY *
124                                                                   globalInvocationZ) +
125                                                                  globalInvocationY) +
126                                                                 globalInvocationX;
127
128                                                         if (0xf != data[offset])
129                                                         {
130                                                                 return false;
131                                                         }
132                                                 }
133                                         }
134                                 }
135                         }
136                 }
137         }
138
139         return true;
140 }
141
142 std::string getOpTypeName(int opType)
143 {
144         switch (opType)
145         {
146                 default:
147                         DE_FATAL("Unsupported op type");
148                 case OPTYPE_INVERSE_BALLOT:
149                         return "subgroupInverseBallot";
150                 case OPTYPE_BALLOT_BIT_EXTRACT:
151                         return "subgroupBallotBitExtract";
152                 case OPTYPE_BALLOT_BIT_COUNT:
153                         return "subgroupBallotBitCount";
154                 case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
155                         return "subgroupBallotInclusiveBitCount";
156                 case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
157                         return "subgroupBallotExclusiveBitCount";
158                 case OPTYPE_BALLOT_FIND_LSB:
159                         return "subgroupBallotFindLSB";
160                 case OPTYPE_BALLOT_FIND_MSB:
161                         return "subgroupBallotFindMSB";
162         }
163 }
164
165 struct CaseDefinition
166 {
167         int                                     opType;
168         VkShaderStageFlags      shaderStage;
169         bool                            noSSBO;
170 };
171
172 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
173 {
174         std::ostringstream bdy;
175
176         bdy << "  uvec4 allOnes = uvec4(0xFFFFFFFF);\n"
177                 << "  uvec4 allZeros = uvec4(0);\n"
178                 << "  uint tempResult = 0;\n"
179                 << "#define MAKE_HIGH_BALLOT_RESULT(i) uvec4("
180                 << "i >= 32 ? 0 : (0xFFFFFFFF << i), "
181                 << "i >= 64 ? 0 : (0xFFFFFFFF << ((i < 32) ? 0 : (i - 32))), "
182                 << "i >= 96 ? 0 : (0xFFFFFFFF << ((i < 64) ? 0 : (i - 64))), "
183                 << " 0xFFFFFFFF << ((i < 96) ? 0 : (i - 96)))\n"
184                 << "#define MAKE_SINGLE_BIT_BALLOT_RESULT(i) uvec4("
185                 << "i >= 32 ? 0 : 0x1 << i, "
186                 << "i < 32 || i >= 64 ? 0 : 0x1 << (i - 32), "
187                 << "i < 64 || i >= 96 ? 0 : 0x1 << (i - 64), "
188                 << "i < 96 ? 0 : 0x1 << (i - 96))\n";
189
190         switch (caseDef.opType)
191         {
192                 default:
193                         DE_FATAL("Unknown op type!");
194                 case OPTYPE_INVERSE_BALLOT:
195                         bdy << "  tempResult |= subgroupInverseBallot(allOnes) ? 0x1 : 0;\n"
196                                 << "  tempResult |= subgroupInverseBallot(allZeros) ? 0 : 0x2;\n"
197                                 << "  tempResult |= subgroupInverseBallot(subgroupBallot(true)) ? 0x4 : 0;\n"
198                                 << "  tempResult |= 0x8;\n";
199                         break;
200                 case OPTYPE_BALLOT_BIT_EXTRACT:
201                         bdy << "  tempResult |= subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID) ? 0x1 : 0;\n"
202                                 << "  tempResult |= subgroupBallotBitExtract(allZeros, gl_SubgroupInvocationID) ? 0 : 0x2;\n"
203                                 << "  tempResult |= subgroupBallotBitExtract(subgroupBallot(true), gl_SubgroupInvocationID) ? 0x4 : 0;\n"
204                                 << "  tempResult |= 0x8;\n"
205                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
206                                 << "  {\n"
207                                 << "    if (!subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID))\n"
208                                 << "    {\n"
209                                 << "      tempResult &= ~0x8;\n"
210                                 << "    }\n"
211                                 << "  }\n";
212                         break;
213                 case OPTYPE_BALLOT_BIT_COUNT:
214                         bdy << "  tempResult |= gl_SubgroupSize == subgroupBallotBitCount(allOnes) ? 0x1 : 0;\n"
215                                 << "  tempResult |= 0 == subgroupBallotBitCount(allZeros) ? 0x2 : 0;\n"
216                                 << "  tempResult |= 0 < subgroupBallotBitCount(subgroupBallot(true)) ? 0x4 : 0;\n"
217                                 << "  tempResult |= 0 == subgroupBallotBitCount(MAKE_HIGH_BALLOT_RESULT(gl_SubgroupSize)) ? 0x8 : 0;\n";
218                         break;
219                 case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
220                         bdy << "  uint inclusiveOffset = gl_SubgroupInvocationID + 1;\n"
221                                 << "  tempResult |= inclusiveOffset == subgroupBallotInclusiveBitCount(allOnes) ? 0x1 : 0;\n"
222                                 << "  tempResult |= 0 == subgroupBallotInclusiveBitCount(allZeros) ? 0x2 : 0;\n"
223                                 << "  tempResult |= 0 < subgroupBallotInclusiveBitCount(subgroupBallot(true)) ? 0x4 : 0;\n"
224                                 << "  tempResult |= 0x8;\n"
225                                 << "  uvec4 inclusiveUndef = MAKE_HIGH_BALLOT_RESULT(inclusiveOffset);\n"
226                                 << "  bool undefTerritory = false;\n"
227                                 << "  for (uint i = 0; i <= 128; i++)\n"
228                                 << "  {\n"
229                                 << "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
230                                 << "    if (iUndef == inclusiveUndef)"
231                                 << "    {\n"
232                                 << "      undefTerritory = true;\n"
233                                 << "    }\n"
234                                 << "    uint inclusiveBitCount = subgroupBallotInclusiveBitCount(iUndef);\n"
235                                 << "    if (undefTerritory && (0 != inclusiveBitCount))\n"
236                                 << "    {\n"
237                                 << "      tempResult &= ~0x8;\n"
238                                 << "    }\n"
239                                 << "    else if (!undefTerritory && (0 == inclusiveBitCount))\n"
240                                 << "    {\n"
241                                 << "      tempResult &= ~0x8;\n"
242                                 << "    }\n"
243                                 << "  }\n";
244                         break;
245                 case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
246                         bdy << "  uint exclusiveOffset = gl_SubgroupInvocationID;\n"
247                                 << "  tempResult |= exclusiveOffset == subgroupBallotExclusiveBitCount(allOnes) ? 0x1 : 0;\n"
248                                 << "  tempResult |= 0 == subgroupBallotExclusiveBitCount(allZeros) ? 0x2 : 0;\n"
249                                 << "  tempResult |= 0x4;\n"
250                                 << "  tempResult |= 0x8;\n"
251                                 << "  uvec4 exclusiveUndef = MAKE_HIGH_BALLOT_RESULT(exclusiveOffset);\n"
252                                 << "  bool undefTerritory = false;\n"
253                                 << "  for (uint i = 0; i <= 128; i++)\n"
254                                 << "  {\n"
255                                 << "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
256                                 << "    if (iUndef == exclusiveUndef)"
257                                 << "    {\n"
258                                 << "      undefTerritory = true;\n"
259                                 << "    }\n"
260                                 << "    uint exclusiveBitCount = subgroupBallotExclusiveBitCount(iUndef);\n"
261                                 << "    if (undefTerritory && (0 != exclusiveBitCount))\n"
262                                 << "    {\n"
263                                 << "      tempResult &= ~0x4;\n"
264                                 << "    }\n"
265                                 << "    else if (!undefTerritory && (0 == exclusiveBitCount))\n"
266                                 << "    {\n"
267                                 << "      tempResult &= ~0x8;\n"
268                                 << "    }\n"
269                                 << "  }\n";
270                         break;
271                 case OPTYPE_BALLOT_FIND_LSB:
272                         bdy << "  tempResult |= 0 == subgroupBallotFindLSB(allOnes) ? 0x1 : 0;\n"
273                                 << "  if (subgroupElect())\n"
274                                 << "  {\n"
275                                 << "    tempResult |= 0x2;\n"
276                                 << "  }\n"
277                                 << "  else\n"
278                                 << "  {\n"
279                                 << "    tempResult |= 0 < subgroupBallotFindLSB(subgroupBallot(true)) ? 0x2 : 0;\n"
280                                 << "  }\n"
281                                 << "  tempResult |= gl_SubgroupSize > subgroupBallotFindLSB(subgroupBallot(true)) ? 0x4 : 0;\n"
282                                 << "  tempResult |= 0x8;\n"
283                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
284                                 << "  {\n"
285                                 << "    if (i != subgroupBallotFindLSB(MAKE_HIGH_BALLOT_RESULT(i)))\n"
286                                 << "    {\n"
287                                 << "      tempResult &= ~0x8;\n"
288                                 << "    }\n"
289                                 << "  }\n";
290                         break;
291                 case OPTYPE_BALLOT_FIND_MSB:
292                         bdy << "  tempResult |= (gl_SubgroupSize - 1) == subgroupBallotFindMSB(allOnes) ? 0x1 : 0;\n"
293                                 << "  if (subgroupElect())\n"
294                                 << "  {\n"
295                                 << "    tempResult |= 0x2;\n"
296                                 << "  }\n"
297                                 << "  else\n"
298                                 << "  {\n"
299                                 << "    tempResult |= 0 < subgroupBallotFindMSB(subgroupBallot(true)) ? 0x2 : 0;\n"
300                                 << "  }\n"
301                                 << "  tempResult |= gl_SubgroupSize > subgroupBallotFindMSB(subgroupBallot(true)) ? 0x4 : 0;\n"
302                                 << "  tempResult |= 0x8;\n"
303                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
304                                 << "  {\n"
305                                 << "    if (i != subgroupBallotFindMSB(MAKE_SINGLE_BIT_BALLOT_RESULT(i)))\n"
306                                 << "    {\n"
307                                 << "      tempResult &= ~0x8;\n"
308                                 << "    }\n"
309                                 << "  }\n";
310                         break;
311         }
312
313         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
314         {
315                 std::ostringstream src;
316                 std::ostringstream      fragmentSrc;
317
318                 src << "#version 450\n"
319                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
320                         << "layout(location = 0) in highp vec4 in_position;\n"
321                         << "layout(location = 0) out float out_color;\n"
322                         << "\n"
323                         << "void main (void)\n"
324                         << "{\n"
325                         << bdy.str()
326                         << "  out_color = float(tempResult);\n"
327                         << "  gl_Position = in_position;\n"
328                         << "}\n";
329
330                 programCollection.glslSources.add("vert") << glu::VertexSource(src.str());
331
332                 fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
333                         << "layout(location = 0) in float in_color;\n"
334                         << "layout(location = 0) out uint out_color;\n"
335                         << "void main()\n"
336                         <<"{\n"
337                         << "    out_color = uint(in_color);\n"
338                         << "}\n";
339                 programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str());
340         }
341         else
342         {
343                 DE_FATAL("Unsupported shader stage");
344         }
345 }
346
347 void initPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
348 {
349         std::ostringstream bdy;
350
351         bdy << "  uvec4 allOnes = uvec4(0xFFFFFFFF);\n"
352                 << "  uvec4 allZeros = uvec4(0);\n"
353                 << "  uint tempResult = 0;\n"
354                 << "#define MAKE_HIGH_BALLOT_RESULT(i) uvec4("
355                 << "i >= 32 ? 0 : (0xFFFFFFFF << i), "
356                 << "i >= 64 ? 0 : (0xFFFFFFFF << ((i < 32) ? 0 : (i - 32))), "
357                 << "i >= 96 ? 0 : (0xFFFFFFFF << ((i < 64) ? 0 : (i - 64))), "
358                 << " 0xFFFFFFFF << ((i < 96) ? 0 : (i - 96)))\n"
359                 << "#define MAKE_SINGLE_BIT_BALLOT_RESULT(i) uvec4("
360                 << "i >= 32 ? 0 : 0x1 << i, "
361                 << "i < 32 || i >= 64 ? 0 : 0x1 << (i - 32), "
362                 << "i < 64 || i >= 96 ? 0 : 0x1 << (i - 64), "
363                 << "i < 96 ? 0 : 0x1 << (i - 96))\n";
364
365         switch (caseDef.opType)
366         {
367                 default:
368                         DE_FATAL("Unknown op type!");
369                 case OPTYPE_INVERSE_BALLOT:
370                         bdy << "  tempResult |= subgroupInverseBallot(allOnes) ? 0x1 : 0;\n"
371                                 << "  tempResult |= subgroupInverseBallot(allZeros) ? 0 : 0x2;\n"
372                                 << "  tempResult |= subgroupInverseBallot(subgroupBallot(true)) ? 0x4 : 0;\n"
373                                 << "  tempResult |= 0x8;\n";
374                         break;
375                 case OPTYPE_BALLOT_BIT_EXTRACT:
376                         bdy << "  tempResult |= subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID) ? 0x1 : 0;\n"
377                                 << "  tempResult |= subgroupBallotBitExtract(allZeros, gl_SubgroupInvocationID) ? 0 : 0x2;\n"
378                                 << "  tempResult |= subgroupBallotBitExtract(subgroupBallot(true), gl_SubgroupInvocationID) ? 0x4 : 0;\n"
379                                 << "  tempResult |= 0x8;\n"
380                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
381                                 << "  {\n"
382                                 << "    if (!subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID))\n"
383                                 << "    {\n"
384                                 << "      tempResult &= ~0x8;\n"
385                                 << "    }\n"
386                                 << "  }\n";
387                         break;
388                 case OPTYPE_BALLOT_BIT_COUNT:
389                         bdy << "  tempResult |= gl_SubgroupSize == subgroupBallotBitCount(allOnes) ? 0x1 : 0;\n"
390                                 << "  tempResult |= 0 == subgroupBallotBitCount(allZeros) ? 0x2 : 0;\n"
391                                 << "  tempResult |= 0 < subgroupBallotBitCount(subgroupBallot(true)) ? 0x4 : 0;\n"
392                                 << "  tempResult |= 0 == subgroupBallotBitCount(MAKE_HIGH_BALLOT_RESULT(gl_SubgroupSize)) ? 0x8 : 0;\n";
393                         break;
394                 case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
395                         bdy << "  uint inclusiveOffset = gl_SubgroupInvocationID + 1;\n"
396                                 << "  tempResult |= inclusiveOffset == subgroupBallotInclusiveBitCount(allOnes) ? 0x1 : 0;\n"
397                                 << "  tempResult |= 0 == subgroupBallotInclusiveBitCount(allZeros) ? 0x2 : 0;\n"
398                                 << "  tempResult |= 0 < subgroupBallotInclusiveBitCount(subgroupBallot(true)) ? 0x4 : 0;\n"
399                                 << "  tempResult |= 0x8;\n"
400                                 << "  uvec4 inclusiveUndef = MAKE_HIGH_BALLOT_RESULT(inclusiveOffset);\n"
401                                 << "  bool undefTerritory = false;\n"
402                                 << "  for (uint i = 0; i <= 128; i++)\n"
403                                 << "  {\n"
404                                 << "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
405                                 << "    if (iUndef == inclusiveUndef)"
406                                 << "    {\n"
407                                 << "      undefTerritory = true;\n"
408                                 << "    }\n"
409                                 << "    uint inclusiveBitCount = subgroupBallotInclusiveBitCount(iUndef);\n"
410                                 << "    if (undefTerritory && (0 != inclusiveBitCount))\n"
411                                 << "    {\n"
412                                 << "      tempResult &= ~0x8;\n"
413                                 << "    }\n"
414                                 << "    else if (!undefTerritory && (0 == inclusiveBitCount))\n"
415                                 << "    {\n"
416                                 << "      tempResult &= ~0x8;\n"
417                                 << "    }\n"
418                                 << "  }\n";
419                         break;
420                 case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
421                         bdy << "  uint exclusiveOffset = gl_SubgroupInvocationID;\n"
422                                 << "  tempResult |= exclusiveOffset == subgroupBallotExclusiveBitCount(allOnes) ? 0x1 : 0;\n"
423                                 << "  tempResult |= 0 == subgroupBallotExclusiveBitCount(allZeros) ? 0x2 : 0;\n"
424                                 << "  tempResult |= 0x4;\n"
425                                 << "  tempResult |= 0x8;\n"
426                                 << "  uvec4 exclusiveUndef = MAKE_HIGH_BALLOT_RESULT(exclusiveOffset);\n"
427                                 << "  bool undefTerritory = false;\n"
428                                 << "  for (uint i = 0; i <= 128; i++)\n"
429                                 << "  {\n"
430                                 << "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
431                                 << "    if (iUndef == exclusiveUndef)"
432                                 << "    {\n"
433                                 << "      undefTerritory = true;\n"
434                                 << "    }\n"
435                                 << "    uint exclusiveBitCount = subgroupBallotExclusiveBitCount(iUndef);\n"
436                                 << "    if (undefTerritory && (0 != exclusiveBitCount))\n"
437                                 << "    {\n"
438                                 << "      tempResult &= ~0x4;\n"
439                                 << "    }\n"
440                                 << "    else if (!undefTerritory && (0 == exclusiveBitCount))\n"
441                                 << "    {\n"
442                                 << "      tempResult &= ~0x8;\n"
443                                 << "    }\n"
444                                 << "  }\n";
445                         break;
446                 case OPTYPE_BALLOT_FIND_LSB:
447                         bdy << "  tempResult |= 0 == subgroupBallotFindLSB(allOnes) ? 0x1 : 0;\n"
448                                 << "  if (subgroupElect())\n"
449                                 << "  {\n"
450                                 << "    tempResult |= 0x2;\n"
451                                 << "  }\n"
452                                 << "  else\n"
453                                 << "  {\n"
454                                 << "    tempResult |= 0 < subgroupBallotFindLSB(subgroupBallot(true)) ? 0x2 : 0;\n"
455                                 << "  }\n"
456                                 << "  tempResult |= gl_SubgroupSize > subgroupBallotFindLSB(subgroupBallot(true)) ? 0x4 : 0;\n"
457                                 << "  tempResult |= 0x8;\n"
458                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
459                                 << "  {\n"
460                                 << "    if (i != subgroupBallotFindLSB(MAKE_HIGH_BALLOT_RESULT(i)))\n"
461                                 << "    {\n"
462                                 << "      tempResult &= ~0x8;\n"
463                                 << "    }\n"
464                                 << "  }\n";
465                         break;
466                 case OPTYPE_BALLOT_FIND_MSB:
467                         bdy << "  tempResult |= (gl_SubgroupSize - 1) == subgroupBallotFindMSB(allOnes) ? 0x1 : 0;\n"
468                                 << "  if (subgroupElect())\n"
469                                 << "  {\n"
470                                 << "    tempResult |= 0x2;\n"
471                                 << "  }\n"
472                                 << "  else\n"
473                                 << "  {\n"
474                                 << "    tempResult |= 0 < subgroupBallotFindMSB(subgroupBallot(true)) ? 0x2 : 0;\n"
475                                 << "  }\n"
476                                 << "  tempResult |= gl_SubgroupSize > subgroupBallotFindMSB(subgroupBallot(true)) ? 0x4 : 0;\n"
477                                 << "  tempResult |= 0x8;\n"
478                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
479                                 << "  {\n"
480                                 << "    if (i != subgroupBallotFindMSB(MAKE_SINGLE_BIT_BALLOT_RESULT(i)))\n"
481                                 << "    {\n"
482                                 << "      tempResult &= ~0x8;\n"
483                                 << "    }\n"
484                                 << "  }\n";
485                         break;
486         }
487
488         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
489         {
490                 std::ostringstream src;
491
492                 src << "#version 450\n"
493                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
494                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
495                         "local_size_z_id = 2) in;\n"
496                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
497                         << "{\n"
498                         << "  uint result[];\n"
499                         << "};\n"
500                         << "\n"
501                         << "void main (void)\n"
502                         << "{\n"
503                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
504                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
505                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
506                         "gl_GlobalInvocationID.x;\n"
507                         << bdy.str()
508                         << "  result[offset] = tempResult;\n"
509                         << "}\n";
510
511                 programCollection.glslSources.add("comp")
512                                 << glu::ComputeSource(src.str());
513         }
514         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
515         {
516                 programCollection.glslSources.add("vert")
517                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage));
518
519                 std::ostringstream frag;
520
521                 frag << "#version 450\n"
522                          << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
523                          << "layout(location = 0) out uint result;\n"
524                          << "void main (void)\n"
525                          << "{\n"
526                          << bdy.str()
527                          << "  result = tempResult;\n"
528                          << "}\n";
529
530                 programCollection.glslSources.add("frag")
531                                 << glu::FragmentSource(frag.str());
532         }
533         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
534         {
535                 std::ostringstream src;
536
537                 src << "#version 450\n"
538                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
539                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
540                         << "{\n"
541                         << "  uint result[];\n"
542                         << "};\n"
543                         << "\n"
544                         << "void main (void)\n"
545                         << "{\n"
546                         << bdy.str()
547                         << "  result[gl_VertexIndex] = tempResult;\n"
548                         << "}\n";
549
550                 programCollection.glslSources.add("vert")
551                                 << glu::VertexSource(src.str());
552         }
553         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
554         {
555                 programCollection.glslSources.add("vert")
556                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage));
557
558                 std::ostringstream src;
559
560                 src << "#version 450\n"
561                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
562                         << "layout(points) in;\n"
563                         << "layout(points, max_vertices = 1) out;\n"
564                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
565                         << "{\n"
566                         << "  uint result[];\n"
567                         << "};\n"
568                         << "\n"
569                         << "void main (void)\n"
570                         << "{\n"
571                         << bdy.str()
572                         << "  result[gl_PrimitiveIDIn] = tempResult;\n"
573                         << "}\n";
574
575                 programCollection.glslSources.add("geom")
576                                 << glu::GeometrySource(src.str());
577         }
578         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
579         {
580                 programCollection.glslSources.add("vert")
581                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage));
582
583                 programCollection.glslSources.add("tese")
584                                 << glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
585
586                 std::ostringstream src;
587
588                 src << "#version 450\n"
589                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
590                         << "layout(vertices=1) out;\n"
591                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
592                         << "{\n"
593                         << "  uint result[];\n"
594                         << "};\n"
595                         << "\n"
596                         << "void main (void)\n"
597                         << "{\n"
598                         << bdy.str()
599                         << "  result[gl_PrimitiveID] = tempResult;\n"
600                         << "}\n";
601
602                 programCollection.glslSources.add("tesc")
603                                 << glu::TessellationControlSource(src.str());
604         }
605         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
606         {
607                 programCollection.glslSources.add("vert")
608                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage));
609
610                 programCollection.glslSources.add("tesc")
611                                 << glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
612
613                 std::ostringstream src;
614
615                 src << "#version 450\n"
616                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
617                         << "layout(isolines) in;\n"
618                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
619                         << "{\n"
620                         << "  uint result[];\n"
621                         << "};\n"
622                         << "\n"
623                         << "void main (void)\n"
624                         << "{\n"
625                         << bdy.str()
626                         << "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
627                         << "}\n";
628
629                 programCollection.glslSources.add("tese")
630                                 << glu::TessellationEvaluationSource(src.str());
631         }
632         else
633         {
634                 DE_FATAL("Unsupported shader stage");
635         }
636 }
637
638 tcu::TestStatus test (Context& context, const CaseDefinition caseDef)
639 {
640         if (!subgroups::isSubgroupSupported(context))
641                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
642
643         if (!subgroups::areSubgroupOperationsSupportedForStage(
644                                 context, caseDef.shaderStage))
645         {
646                 if (subgroups::areSubgroupOperationsRequiredForStage(
647                                         caseDef.shaderStage))
648                 {
649                         return tcu::TestStatus::fail(
650                                            "Shader stage " +
651                                 subgroups::getShaderStageName(caseDef.shaderStage) +
652                                 " is required to support subgroup operations!");
653                 }
654                 else
655                 {
656                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
657                 }
658         }
659
660         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BALLOT_BIT))
661         {
662                 TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
663         }
664
665         if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
666         {
667                 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT,
668                                                                                  DE_NULL, 0, checkVertexPipelineStages);
669         }
670
671         //Tests which don't use the SSBO
672         if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
673                         (VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
674         {
675                 if (!subgroups::isVertexSSBOSupportedForDevice(context))
676                 {
677                         TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
678                 }
679         }
680
681         if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
682         {
683                 return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT,
684                                                                                    DE_NULL, 0, checkFragment);
685         }
686         else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
687         {
688                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT,
689                                                                                   DE_NULL, 0, checkCompute);
690         }
691         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
692         {
693                 return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT,
694                                                                                  DE_NULL, 0, checkVertexPipelineStages);
695         }
696         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
697         {
698                 return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT,
699                                                                                    DE_NULL, 0, checkVertexPipelineStages);
700         }
701         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
702         {
703                 return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT,
704                                 DE_NULL, 0, checkVertexPipelineStages);
705         }
706         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
707         {
708                 return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT,
709                                 DE_NULL, 0, checkVertexPipelineStages);
710         }
711         return tcu::TestStatus::pass("OK");
712 }
713 }
714
715 namespace vkt
716 {
717 namespace subgroups
718 {
719 tcu::TestCaseGroup* createSubgroupsBallotOtherTests(tcu::TestContext& testCtx)
720 {
721         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
722                         testCtx, "ballot_other", "Subgroup ballot other category tests"));
723
724         const VkShaderStageFlags stages[] =
725         {
726                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
727                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
728                 VK_SHADER_STAGE_GEOMETRY_BIT,
729                 VK_SHADER_STAGE_VERTEX_BIT,
730                 VK_SHADER_STAGE_FRAGMENT_BIT,
731                 VK_SHADER_STAGE_COMPUTE_BIT
732         };
733
734         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
735         {
736                 const VkShaderStageFlags stage = stages[stageIndex];
737
738                 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
739                 {
740                         CaseDefinition caseDef = {opTypeIndex, stage, false};
741
742                         std::ostringstream name;
743
744                         std::string op = getOpTypeName(opTypeIndex);
745
746                         name << de::toLower(op) << "_" << getShaderStageName(stage);
747
748                         addFunctionCaseWithPrograms(group.get(), name.str(),
749                                                                                 "", initPrograms, test, caseDef);
750
751                         if (VK_SHADER_STAGE_VERTEX_BIT & stage )
752                         {
753                                 caseDef.noSSBO = true;
754                                 addFunctionCaseWithPrograms(group.get(), name.str() + "_framebuffer", "",
755                                                                 initFrameBufferPrograms, test, caseDef);
756                         }
757
758                 }
759         }
760
761         return group.release();
762 }
763
764 } // subgroups
765 } // vkt