c16cbe0e01fd443b8cf95a1482df88c329daca3a
[platform/upstream/VK-GL-CTS.git] / external / openglcts / modules / common / subgroups / glcSubgroupsBallotOtherTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsBallotOtherTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
// Ballot built-ins covered by this file. Values are contiguous from zero so
// that the test-generation loop can iterate over [0, OPTYPE_LAST).
enum OpType
{
	OPTYPE_INVERSE_BALLOT				= 0,	// subgroupInverseBallot
	OPTYPE_BALLOT_BIT_EXTRACT			= 1,	// subgroupBallotBitExtract
	OPTYPE_BALLOT_BIT_COUNT				= 2,	// subgroupBallotBitCount
	OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT	= 3,	// subgroupBallotInclusiveBitCount
	OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT	= 4,	// subgroupBallotExclusiveBitCount
	OPTYPE_BALLOT_FIND_LSB				= 5,	// subgroupBallotFindLSB
	OPTYPE_BALLOT_FIND_MSB				= 6,	// subgroupBallotFindMSB
	OPTYPE_LAST							= 7		// count of op types, not an operation itself
};
49
50 static bool checkVertexPipelineStages(std::vector<const void*> datas,
51                                                                           deUint32 width, deUint32)
52 {
53         return vkt::subgroups::check(datas, width, 0xf);
54 }
55
56 static bool checkCompute(std::vector<const void*> datas,
57                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
58                                                  deUint32)
59 {
60         return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 0xf);
61 }
62
63 std::string getOpTypeName(int opType)
64 {
65         switch (opType)
66         {
67                 default:
68                         DE_FATAL("Unsupported op type");
69                         return "";
70                 case OPTYPE_INVERSE_BALLOT:
71                         return "subgroupInverseBallot";
72                 case OPTYPE_BALLOT_BIT_EXTRACT:
73                         return "subgroupBallotBitExtract";
74                 case OPTYPE_BALLOT_BIT_COUNT:
75                         return "subgroupBallotBitCount";
76                 case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
77                         return "subgroupBallotInclusiveBitCount";
78                 case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
79                         return "subgroupBallotExclusiveBitCount";
80                 case OPTYPE_BALLOT_FIND_LSB:
81                         return "subgroupBallotFindLSB";
82                 case OPTYPE_BALLOT_FIND_MSB:
83                         return "subgroupBallotFindMSB";
84         }
85 }
86
87 struct CaseDefinition
88 {
89         int                                     opType;
90         VkShaderStageFlags      shaderStage;
91 };
92
93 std::string getBodySource(CaseDefinition caseDef)
94 {
95         std::ostringstream bdy;
96
97         bdy << "  uvec4 allOnes = uvec4(0xFFFFFFFF);\n"
98                 << "  uvec4 allZeros = uvec4(0);\n"
99                 << "  uint tempResult = 0;\n"
100                 << "#define MAKE_HIGH_BALLOT_RESULT(i) uvec4("
101                 << "i >= 32 ? 0 : (0xFFFFFFFF << i), "
102                 << "i >= 64 ? 0 : (0xFFFFFFFF << ((i < 32) ? 0 : (i - 32))), "
103                 << "i >= 96 ? 0 : (0xFFFFFFFF << ((i < 64) ? 0 : (i - 64))), "
104                 << " 0xFFFFFFFF << ((i < 96) ? 0 : (i - 96)))\n"
105                 << "#define MAKE_SINGLE_BIT_BALLOT_RESULT(i) uvec4("
106                 << "i >= 32 ? 0 : 0x1 << i, "
107                 << "i < 32 || i >= 64 ? 0 : 0x1 << (i - 32), "
108                 << "i < 64 || i >= 96 ? 0 : 0x1 << (i - 64), "
109                 << "i < 96 ? 0 : 0x1 << (i - 96))\n";
110
111         switch (caseDef.opType)
112         {
113                 default:
114                         DE_FATAL("Unknown op type!");
115                         break;
116                 case OPTYPE_INVERSE_BALLOT:
117                         bdy << "  tempResult |= subgroupInverseBallot(allOnes) ? 0x1 : 0;\n"
118                                 << "  tempResult |= subgroupInverseBallot(allZeros) ? 0 : 0x2;\n"
119                                 << "  tempResult |= subgroupInverseBallot(subgroupBallot(true)) ? 0x4 : 0;\n"
120                                 << "  tempResult |= 0x8;\n";
121                         break;
122                 case OPTYPE_BALLOT_BIT_EXTRACT:
123                         bdy << "  tempResult |= subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID) ? 0x1 : 0;\n"
124                                 << "  tempResult |= subgroupBallotBitExtract(allZeros, gl_SubgroupInvocationID) ? 0 : 0x2;\n"
125                                 << "  tempResult |= subgroupBallotBitExtract(subgroupBallot(true), gl_SubgroupInvocationID) ? 0x4 : 0;\n"
126                                 << "  tempResult |= 0x8;\n"
127                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
128                                 << "  {\n"
129                                 << "    if (!subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID))\n"
130                                 << "    {\n"
131                                 << "      tempResult &= ~0x8;\n"
132                                 << "    }\n"
133                                 << "  }\n";
134                         break;
135                 case OPTYPE_BALLOT_BIT_COUNT:
136                         bdy << "  tempResult |= gl_SubgroupSize == subgroupBallotBitCount(allOnes) ? 0x1 : 0;\n"
137                                 << "  tempResult |= 0 == subgroupBallotBitCount(allZeros) ? 0x2 : 0;\n"
138                                 << "  tempResult |= 0 < subgroupBallotBitCount(subgroupBallot(true)) ? 0x4 : 0;\n"
139                                 << "  tempResult |= 0 == subgroupBallotBitCount(MAKE_HIGH_BALLOT_RESULT(gl_SubgroupSize)) ? 0x8 : 0;\n";
140                         break;
141                 case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
142                         bdy << "  uint inclusiveOffset = gl_SubgroupInvocationID + 1;\n"
143                                 << "  tempResult |= inclusiveOffset == subgroupBallotInclusiveBitCount(allOnes) ? 0x1 : 0;\n"
144                                 << "  tempResult |= 0 == subgroupBallotInclusiveBitCount(allZeros) ? 0x2 : 0;\n"
145                                 << "  tempResult |= 0 < subgroupBallotInclusiveBitCount(subgroupBallot(true)) ? 0x4 : 0;\n"
146                                 << "  tempResult |= 0x8;\n"
147                                 << "  uvec4 inclusiveUndef = MAKE_HIGH_BALLOT_RESULT(inclusiveOffset);\n"
148                                 << "  bool undefTerritory = false;\n"
149                                 << "  for (uint i = 0; i <= 128; i++)\n"
150                                 << "  {\n"
151                                 << "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
152                                 << "    if (iUndef == inclusiveUndef)"
153                                 << "    {\n"
154                                 << "      undefTerritory = true;\n"
155                                 << "    }\n"
156                                 << "    uint inclusiveBitCount = subgroupBallotInclusiveBitCount(iUndef);\n"
157                                 << "    if (undefTerritory && (0 != inclusiveBitCount))\n"
158                                 << "    {\n"
159                                 << "      tempResult &= ~0x8;\n"
160                                 << "    }\n"
161                                 << "    else if (!undefTerritory && (0 == inclusiveBitCount))\n"
162                                 << "    {\n"
163                                 << "      tempResult &= ~0x8;\n"
164                                 << "    }\n"
165                                 << "  }\n";
166                         break;
167                 case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
168                         bdy << "  uint exclusiveOffset = gl_SubgroupInvocationID;\n"
169                                 << "  tempResult |= exclusiveOffset == subgroupBallotExclusiveBitCount(allOnes) ? 0x1 : 0;\n"
170                                 << "  tempResult |= 0 == subgroupBallotExclusiveBitCount(allZeros) ? 0x2 : 0;\n"
171                                 << "  tempResult |= 0x4;\n"
172                                 << "  tempResult |= 0x8;\n"
173                                 << "  uvec4 exclusiveUndef = MAKE_HIGH_BALLOT_RESULT(exclusiveOffset);\n"
174                                 << "  bool undefTerritory = false;\n"
175                                 << "  for (uint i = 0; i <= 128; i++)\n"
176                                 << "  {\n"
177                                 << "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
178                                 << "    if (iUndef == exclusiveUndef)"
179                                 << "    {\n"
180                                 << "      undefTerritory = true;\n"
181                                 << "    }\n"
182                                 << "    uint exclusiveBitCount = subgroupBallotExclusiveBitCount(iUndef);\n"
183                                 << "    if (undefTerritory && (0 != exclusiveBitCount))\n"
184                                 << "    {\n"
185                                 << "      tempResult &= ~0x4;\n"
186                                 << "    }\n"
187                                 << "    else if (!undefTerritory && (0 == exclusiveBitCount))\n"
188                                 << "    {\n"
189                                 << "      tempResult &= ~0x8;\n"
190                                 << "    }\n"
191                                 << "  }\n";
192                         break;
193                 case OPTYPE_BALLOT_FIND_LSB:
194                         bdy << "  tempResult |= 0 == subgroupBallotFindLSB(allOnes) ? 0x1 : 0;\n"
195                                 << "  if (subgroupElect())\n"
196                                 << "  {\n"
197                                 << "    tempResult |= 0x2;\n"
198                                 << "  }\n"
199                                 << "  else\n"
200                                 << "  {\n"
201                                 << "    tempResult |= 0 < subgroupBallotFindLSB(subgroupBallot(true)) ? 0x2 : 0;\n"
202                                 << "  }\n"
203                                 << "  tempResult |= gl_SubgroupSize > subgroupBallotFindLSB(subgroupBallot(true)) ? 0x4 : 0;\n"
204                                 << "  tempResult |= 0x8;\n"
205                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
206                                 << "  {\n"
207                                 << "    if (i != subgroupBallotFindLSB(MAKE_HIGH_BALLOT_RESULT(i)))\n"
208                                 << "    {\n"
209                                 << "      tempResult &= ~0x8;\n"
210                                 << "    }\n"
211                                 << "  }\n";
212                         break;
213                 case OPTYPE_BALLOT_FIND_MSB:
214                         bdy << "  tempResult |= (gl_SubgroupSize - 1) == subgroupBallotFindMSB(allOnes) ? 0x1 : 0;\n"
215                                 << "  if (subgroupElect())\n"
216                                 << "  {\n"
217                                 << "    tempResult |= 0x2;\n"
218                                 << "  }\n"
219                                 << "  else\n"
220                                 << "  {\n"
221                                 << "    tempResult |= 0 < subgroupBallotFindMSB(subgroupBallot(true)) ? 0x2 : 0;\n"
222                                 << "  }\n"
223                                 << "  tempResult |= gl_SubgroupSize > subgroupBallotFindMSB(subgroupBallot(true)) ? 0x4 : 0;\n"
224                                 << "  tempResult |= 0x8;\n"
225                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
226                                 << "  {\n"
227                                 << "    if (i != subgroupBallotFindMSB(MAKE_SINGLE_BIT_BALLOT_RESULT(i)))\n"
228                                 << "    {\n"
229                                 << "      tempResult &= ~0x8;\n"
230                                 << "    }\n"
231                                 << "  }\n";
232                         break;
233         }
234    return bdy.str();
235 }
236
237 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
238 {
239         const vk::ShaderBuildOptions    buildOptions    (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
240
241         subgroups::setFragmentShaderFrameBuffer(programCollection);
242
243         if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
244                 subgroups::setVertexShaderFrameBuffer(programCollection);
245
246         std::string bdyStr = getBodySource(caseDef);
247
248         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
249         {
250                 std::ostringstream                              vertex;
251                 vertex << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
252                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
253                         << "layout(location = 0) in highp vec4 in_position;\n"
254                         << "layout(location = 0) out float out_color;\n"
255                         << "\n"
256                         << "void main (void)\n"
257                         << "{\n"
258                         << bdyStr
259                         << "  out_color = float(tempResult);\n"
260                         << "  gl_Position = in_position;\n"
261                         << "  gl_PointSize = 1.0f;\n"
262                         << "}\n";
263                 programCollection.glslSources.add("vert")
264                         << glu::VertexSource(vertex.str()) << buildOptions;
265         }
266         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
267         {
268                 std::ostringstream geometry;
269
270                 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
271                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
272                         << "layout(points) in;\n"
273                         << "layout(points, max_vertices = 1) out;\n"
274                         << "layout(location = 0) out float out_color;\n"
275                         << "void main (void)\n"
276                         << "{\n"
277                         << bdyStr
278                         << "  out_color = float(tempResult);\n"
279                         << "  gl_Position = gl_in[0].gl_Position;\n"
280                         << "  EmitVertex();\n"
281                         << "  EndPrimitive();\n"
282                         << "}\n";
283
284                 programCollection.glslSources.add("geometry")
285                         << glu::GeometrySource(geometry.str()) << buildOptions;
286         }
287         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
288         {
289                 std::ostringstream controlSource;
290
291                 controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
292                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
293                         << "layout(vertices = 2) out;\n"
294                         << "layout(location = 0) out float out_color[];\n"
295                         << "\n"
296                         << "void main (void)\n"
297                         << "{\n"
298                         << "  if (gl_InvocationID == 0)\n"
299                         << "  {\n"
300                         << "    gl_TessLevelOuter[0] = 1.0f;\n"
301                         << "    gl_TessLevelOuter[1] = 1.0f;\n"
302                         << "  }\n"
303                         << bdyStr
304                         << "  out_color[gl_InvocationID ] = float(tempResult);\n"
305                         << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
306                         << "}\n";
307
308                 programCollection.glslSources.add("tesc")
309                         << glu::TessellationControlSource(controlSource.str()) << buildOptions;
310                 subgroups::setTesEvalShaderFrameBuffer(programCollection);
311         }
312         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
313         {
314                 std::ostringstream evaluationSource;
315                 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
316                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
317                         << "layout(isolines, equal_spacing, ccw ) in;\n"
318                         << "layout(location = 0) out float out_color;\n"
319                         << "void main (void)\n"
320                         << "{\n"
321                         << bdyStr
322                         << "  out_color  = float(tempResult);\n"
323                         << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
324                         << "}\n";
325
326                 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
327                 programCollection.glslSources.add("tese")
328                         << glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
329         }
330         else
331         {
332                 DE_FATAL("Unsupported shader stage");
333         }
334 }
335
336 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
337 {
338         std::string bdyStr = getBodySource(caseDef);
339
340         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
341         {
342                 std::ostringstream src;
343
344                 src << "#version 450\n"
345                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
346                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
347                         "local_size_z_id = 2) in;\n"
348                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
349                         << "{\n"
350                         << "  uint result[];\n"
351                         << "};\n"
352                         << "\n"
353                         << "void main (void)\n"
354                         << "{\n"
355                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
356                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
357                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
358                         "gl_GlobalInvocationID.x;\n"
359                         << bdyStr
360                         << "  result[offset] = tempResult;\n"
361                         << "}\n";
362
363                 programCollection.glslSources.add("comp")
364                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
365         }
366         else
367         {
368                 const string vertex =
369                         "#version 450\n"
370                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
371                         "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
372                         "{\n"
373                         "  uint result[];\n"
374                         "};\n"
375                         "\n"
376                         "void main (void)\n"
377                         "{\n"
378                         + bdyStr +
379                         "  result[gl_VertexIndex] = tempResult;\n"
380                         "  float pixelSize = 2.0f/1024.0f;\n"
381                         "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
382                         "  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
383                         "  gl_PointSize = 1.0f;\n"
384                         "}\n";
385
386                 const string tesc =
387                         "#version 450\n"
388                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
389                         "layout(vertices=1) out;\n"
390                         "layout(set = 0, binding = 1, std430) buffer Buffer1\n"
391                         "{\n"
392                         "  uint result[];\n"
393                         "};\n"
394                         "\n"
395                         "void main (void)\n"
396                         "{\n"
397                         + bdyStr +
398                         "  result[gl_PrimitiveID] = tempResult;\n"
399                         "  if (gl_InvocationID == 0)\n"
400                         "  {\n"
401                         "    gl_TessLevelOuter[0] = 1.0f;\n"
402                         "    gl_TessLevelOuter[1] = 1.0f;\n"
403                         "  }\n"
404                         "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
405                         "}\n";
406
407                 const string tese =
408                         "#version 450\n"
409                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
410                         "layout(isolines) in;\n"
411                         "layout(set = 0, binding = 2, std430) buffer Buffer1\n"
412                         "{\n"
413                         "  uint result[];\n"
414                         "};\n"
415                         "\n"
416                         "void main (void)\n"
417                         "{\n"
418                         + bdyStr +
419                         "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
420                         "  float pixelSize = 2.0f/1024.0f;\n"
421                         "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
422                         "}\n";
423
424                 const string geometry =
425                         "#version 450\n"
426                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
427                         "layout(${TOPOLOGY}) in;\n"
428                         "layout(points, max_vertices = 1) out;\n"
429                         "layout(set = 0, binding = 3, std430) buffer Buffer1\n"
430                         "{\n"
431                         "  uint result[];\n"
432                         "};\n"
433                         "\n"
434                         "void main (void)\n"
435                         "{\n"
436                         + bdyStr +
437                         "  result[gl_PrimitiveIDIn] = tempResult;\n"
438                         "  gl_Position = gl_in[0].gl_Position;\n"
439                         "  EmitVertex();\n"
440                         "  EndPrimitive();\n"
441                         "}\n";
442
443                 const string fragment =
444                         "#version 450\n"
445                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
446                         "layout(location = 0) out uint result;\n"
447                         "void main (void)\n"
448                         "{\n"
449                         + bdyStr +
450                         "  result = tempResult;\n"
451                         "}\n";
452
453                 subgroups::addNoSubgroupShader(programCollection);
454
455                 programCollection.glslSources.add("vert")
456                                 << glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
457                 programCollection.glslSources.add("tesc")
458                                 << glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
459                 programCollection.glslSources.add("tese")
460                                 << glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
461                 subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
462                                                                                                   programCollection.glslSources);
463                 programCollection.glslSources.add("fragment")
464                                 << glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
465         }
466 }
467
468 void supportedCheck (Context& context, CaseDefinition caseDef)
469 {
470         DE_UNREF(caseDef);
471         if (!subgroups::isSubgroupSupported(context))
472                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
473
474         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BALLOT_BIT))
475         {
476                 TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
477         }
478 }
479
480 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
481 {
482         if (!subgroups::areSubgroupOperationsSupportedForStage(
483                         context, caseDef.shaderStage))
484         {
485                 if (subgroups::areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
486                 {
487                         return tcu::TestStatus::fail(
488                                            "Shader stage " +
489                                            subgroups::getShaderStageName(caseDef.shaderStage) +
490                                            " is required to support subgroup operations!");
491                 }
492                 else
493                 {
494                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
495                 }
496         }
497
498         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
499                 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
500         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
501                 return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
502         else if ((VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT) & caseDef.shaderStage)
503                 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
504         else
505                 TCU_THROW(InternalError, "Unhandled shader stage");
506 }
507
508 tcu::TestStatus test (Context& context, const CaseDefinition caseDef)
509 {
510         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
511         {
512                 if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
513                 {
514                         return tcu::TestStatus::fail(
515                                            "Shader stage " +
516                                 subgroups::getShaderStageName(caseDef.shaderStage) +
517                                 " is required to support subgroup operations!");
518                 }
519                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkCompute);
520         }
521         else
522         {
523                 VkPhysicalDeviceSubgroupProperties subgroupProperties;
524                 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
525                 subgroupProperties.pNext = DE_NULL;
526
527                 VkPhysicalDeviceProperties2 properties;
528                 properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
529                 properties.pNext = &subgroupProperties;
530
531                 context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
532
533                 VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage  & subgroupProperties.supportedStages);
534
535                 if ( VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
536                 {
537                         if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
538                                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
539                         else
540                                 stages = VK_SHADER_STAGE_FRAGMENT_BIT;
541                 }
542
543                 if ((VkShaderStageFlagBits)0u == stages)
544                         TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
545
546                 return subgroups::allStages(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages, stages);
547         }
548         return tcu::TestStatus::pass("OK");
549 }
550 }
551
552 namespace vkt
553 {
554 namespace subgroups
555 {
556 tcu::TestCaseGroup* createSubgroupsBallotOtherTests(tcu::TestContext& testCtx)
557 {
558         de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
559                 testCtx, "graphics", "Subgroup ballot other category tests: graphics"));
560         de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
561                 testCtx, "compute", "Subgroup ballot other category tests: compute"));
562         de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
563                 testCtx, "framebuffer", "Subgroup ballot other category tests: framebuffer"));
564
565         const VkShaderStageFlags stages[] =
566         {
567                 VK_SHADER_STAGE_VERTEX_BIT,
568                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
569                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
570                 VK_SHADER_STAGE_GEOMETRY_BIT,
571         };
572
573         for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
574         {
575                 const string    op              = de::toLower(getOpTypeName(opTypeIndex));
576                 {
577                         const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT};
578                         addFunctionCaseWithPrograms(computeGroup.get(), op, "", supportedCheck, initPrograms, test, caseDef);
579                 }
580
581                 {
582                         const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_ALL_GRAPHICS};
583                         addFunctionCaseWithPrograms(graphicGroup.get(), op, "", supportedCheck, initPrograms, test, caseDef);
584                 }
585
586                 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
587                 {
588                         const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex]};
589                         addFunctionCaseWithPrograms(framebufferGroup.get(), op + "_" + getShaderStageName(caseDef.shaderStage), "", supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
590                 }
591         }
592
593         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
594                 testCtx, "ballot_other", "Subgroup ballot other category tests"));
595
596         group->addChild(graphicGroup.release());
597         group->addChild(computeGroup.release());
598         group->addChild(framebufferGroup.release());
599
600         return group.release();
601 }
602
603 } // subgroups
604 } // vkt