e32862d1e6fc24a383ed8c0ce45ff66eada7bfdf
[platform/upstream/VK-GL-CTS.git] / external / openglcts / modules / common / subgroups / glcSubgroupsShapeTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsShapeTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
38 static bool checkVertexPipelineStages(std::vector<const void*> datas,
39                                                                           deUint32 width, deUint32)
40 {
41         return vkt::subgroups::check(datas, width, 1);
42 }
43
44 static bool checkCompute(std::vector<const void*> datas,
45                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
46                                                  deUint32)
47 {
48         return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
49 }
50
// Subgroup operation families whose shape is exercised by the tests in this file.
enum OpType
{
        OPTYPE_CLUSTERED,       // clustered operations; first enumerator, so its value is 0
        OPTYPE_QUAD,            // quad operations
        OPTYPE_LAST             // count of the op types above; used as a loop bound
};
57
58 std::string getOpTypeName(int opType)
59 {
60         switch (opType)
61         {
62                 default:
63                         DE_FATAL("Unsupported op type");
64                         return "";
65                 case OPTYPE_CLUSTERED:
66                         return "clustered";
67                 case OPTYPE_QUAD:
68                         return "quad";
69         }
70 }
71
72 struct CaseDefinition
73 {
74         int                                     opType;
75         VkShaderStageFlags      shaderStage;
76 };
77
78 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
79 {
80         const vk::ShaderBuildOptions    buildOptions    (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
81         std::ostringstream                              bdy;
82         std::string                                             extension = (OPTYPE_CLUSTERED == caseDef.opType) ?
83                                                                                 "#extension GL_KHR_shader_subgroup_clustered: enable\n" :
84                                                                                 "#extension GL_KHR_shader_subgroup_quad: enable\n";
85
86         subgroups::setFragmentShaderFrameBuffer(programCollection);
87
88         if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
89                 subgroups::setVertexShaderFrameBuffer(programCollection);
90
91         extension += "#extension GL_KHR_shader_subgroup_ballot: enable\n";
92
93         bdy << "  uint tempResult = 0x1;\n"
94                 << "  uvec4 mask = subgroupBallot(true);\n";
95
96         if (OPTYPE_CLUSTERED == caseDef.opType)
97         {
98                 for (deUint32 i = 1; i <= subgroups::maxSupportedSubgroupSize(); i *= 2)
99                 {
100                         bdy << "  if (gl_SubgroupSize >= " << i << ")\n"
101                                 << "  {\n"
102                                 << "    uvec4 contribution = uvec4(0);\n"
103                                 << "    const uint modID = gl_SubgroupInvocationID % 32;\n"
104                                 << "    switch (gl_SubgroupInvocationID / 32)\n"
105                                 << "    {\n"
106                                 << "    case 0: contribution.x = 1 << modID; break;\n"
107                                 << "    case 1: contribution.y = 1 << modID; break;\n"
108                                 << "    case 2: contribution.z = 1 << modID; break;\n"
109                                 << "    case 3: contribution.w = 1 << modID; break;\n"
110                                 << "    }\n"
111                                 << "    uvec4 result = subgroupClusteredOr(contribution, " << i << ");\n"
112                                 << "    uint rootID = gl_SubgroupInvocationID & ~(" << i - 1 << ");\n"
113                                 << "    for (uint i = 0; i < " << i << "; i++)\n"
114                                 << "    {\n"
115                                 << "      uint nextID = rootID + i;\n"
116                                 << "      if (subgroupBallotBitExtract(mask, nextID) ^^ subgroupBallotBitExtract(result, nextID))\n"
117                                 << "      {\n"
118                                 << "        tempResult = 0;\n"
119                                 << "      }\n"
120                                 << "    }\n"
121                                 << "  }\n";
122                 }
123         }
124         else
125         {
126                 bdy << "  uint cluster[4] =\n"
127                         << "  {\n"
128                         << "    subgroupQuadBroadcast(gl_SubgroupInvocationID, 0),\n"
129                         << "    subgroupQuadBroadcast(gl_SubgroupInvocationID, 1),\n"
130                         << "    subgroupQuadBroadcast(gl_SubgroupInvocationID, 2),\n"
131                         << "    subgroupQuadBroadcast(gl_SubgroupInvocationID, 3)\n"
132                         << "  };\n"
133                         << "  uint rootID = gl_SubgroupInvocationID & ~0x3;\n"
134                         << "  for (uint i = 0; i < 4; i++)\n"
135                         << "  {\n"
136                         << "    uint nextID = rootID + i;\n"
137                         << "    if (subgroupBallotBitExtract(mask, nextID) && (cluster[i] != nextID))\n"
138                         << "    {\n"
139                         << "      tempResult = mask.x;\n"
140                         << "    }\n"
141                         << "  }\n";
142         }
143
144         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
145         {
146                 std::ostringstream vertexSrc;
147                 vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
148                         << extension
149                         << "layout(location = 0) in highp vec4 in_position;\n"
150                         << "layout(location = 0) out float result;\n"
151                         << "\n"
152                         << "void main (void)\n"
153                         << "{\n"
154                         << bdy.str()
155                         << "  result = float(tempResult);\n"
156                         << "  gl_Position = in_position;\n"
157                         << "  gl_PointSize = 1.0f;\n"
158                         << "}\n";
159                 programCollection.glslSources.add("vert")
160                         << glu::VertexSource(vertexSrc.str()) << buildOptions;
161         }
162         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
163         {
164                 std::ostringstream geometry;
165
166                 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
167                         << extension
168                         << "layout(points) in;\n"
169                         << "layout(points, max_vertices = 1) out;\n"
170                         << "layout(location = 0) out float out_color;\n"
171                         << "\n"
172                         << "void main (void)\n"
173                         << "{\n"
174                         << bdy.str()
175                         << "  out_color = float(tempResult);\n"
176                         << "  gl_Position = gl_in[0].gl_Position;\n"
177                         << "  EmitVertex();\n"
178                         << "  EndPrimitive();\n"
179                         << "}\n";
180
181                 programCollection.glslSources.add("geometry")
182                         << glu::GeometrySource(geometry.str()) << buildOptions;
183         }
184         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
185         {
186                 std::ostringstream controlSource;
187
188                 controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
189                         << extension
190                         << "layout(vertices = 2) out;\n"
191                         << "layout(location = 0) out float out_color[];\n"
192                         << "\n"
193                         << "void main (void)\n"
194                         << "{\n"
195                         << "  if (gl_InvocationID == 0)\n"
196                         <<"  {\n"
197                         << "    gl_TessLevelOuter[0] = 1.0f;\n"
198                         << "    gl_TessLevelOuter[1] = 1.0f;\n"
199                         << "  }\n"
200                         << bdy.str()
201                         << "  out_color[gl_InvocationID] = float(tempResult);\n"
202                         << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
203                         << "}\n";
204
205                 programCollection.glslSources.add("tesc")
206                         << glu::TessellationControlSource(controlSource.str()) << buildOptions;
207                 subgroups::setTesEvalShaderFrameBuffer(programCollection);
208         }
209         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
210         {
211                 std::ostringstream evaluationSource;
212
213                 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
214                         << extension
215                         << "layout(isolines, equal_spacing, ccw ) in;\n"
216                         << "layout(location = 0) out float out_color;\n"
217                         << "void main (void)\n"
218                         << "{\n"
219                         << bdy.str()
220                         << "  out_color = float(tempResult);\n"
221                         << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
222                         << "}\n";
223
224                 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
225                 programCollection.glslSources.add("tese")
226                                 << glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
227         }
228         else
229         {
230                 DE_FATAL("Unsupported shader stage");
231         }
232 }
233
234 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
235 {
236         std::string extension = (OPTYPE_CLUSTERED == caseDef.opType) ?
237                                                         "#extension GL_KHR_shader_subgroup_clustered: enable\n" :
238                                                         "#extension GL_KHR_shader_subgroup_quad: enable\n";
239
240         extension += "#extension GL_KHR_shader_subgroup_ballot: enable\n";
241
242         std::ostringstream bdy;
243
244         bdy << "  uint tempResult = 0x1;\n"
245                 << "  uvec4 mask = subgroupBallot(true);\n";
246
247         if (OPTYPE_CLUSTERED == caseDef.opType)
248         {
249                 for (deUint32 i = 1; i <= subgroups::maxSupportedSubgroupSize(); i *= 2)
250                 {
251                         bdy << "  if (gl_SubgroupSize >= " << i << ")\n"
252                                 << "  {\n"
253                                 << "    uvec4 contribution = uvec4(0);\n"
254                                 << "    const uint modID = gl_SubgroupInvocationID % 32;\n"
255                                 << "    switch (gl_SubgroupInvocationID / 32)\n"
256                                 << "    {\n"
257                                 << "    case 0: contribution.x = 1 << modID; break;\n"
258                                 << "    case 1: contribution.y = 1 << modID; break;\n"
259                                 << "    case 2: contribution.z = 1 << modID; break;\n"
260                                 << "    case 3: contribution.w = 1 << modID; break;\n"
261                                 << "    }\n"
262                                 << "    uvec4 result = subgroupClusteredOr(contribution, " << i << ");\n"
263                                 << "    uint rootID = gl_SubgroupInvocationID & ~(" << i - 1 << ");\n"
264                                 << "    for (uint i = 0; i < " << i << "; i++)\n"
265                                 << "    {\n"
266                                 << "      uint nextID = rootID + i;\n"
267                                 << "      if (subgroupBallotBitExtract(mask, nextID) ^^ subgroupBallotBitExtract(result, nextID))\n"
268                                 << "      {\n"
269                                 << "        tempResult = 0;\n"
270                                 << "      }\n"
271                                 << "    }\n"
272                                 << "  }\n";
273                 }
274         }
275         else
276         {
277                 bdy << "  uint cluster[4] =\n"
278                         << "  {\n"
279                         << "    subgroupQuadBroadcast(gl_SubgroupInvocationID, 0),\n"
280                         << "    subgroupQuadBroadcast(gl_SubgroupInvocationID, 1),\n"
281                         << "    subgroupQuadBroadcast(gl_SubgroupInvocationID, 2),\n"
282                         << "    subgroupQuadBroadcast(gl_SubgroupInvocationID, 3)\n"
283                         << "  };\n"
284                         << "  uint rootID = gl_SubgroupInvocationID & ~0x3;\n"
285                         << "  for (uint i = 0; i < 4; i++)\n"
286                         << "  {\n"
287                         << "    uint nextID = rootID + i;\n"
288                         << "    if (subgroupBallotBitExtract(mask, nextID) && (cluster[i] != nextID))\n"
289                         << "    {\n"
290                         << "      tempResult = mask.x;\n"
291                         << "    }\n"
292                         << "  }\n";
293         }
294
295         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
296         {
297                 std::ostringstream src;
298
299                 src << "#version 450\n"
300                         << extension
301                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
302                         "local_size_z_id = 2) in;\n"
303                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
304                         << "{\n"
305                         << "  uint result[];\n"
306                         << "};\n"
307                         << "\n"
308                         << "void main (void)\n"
309                         << "{\n"
310                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
311                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
312                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
313                         "gl_GlobalInvocationID.x;\n"
314                         << bdy.str()
315                         << "  result[offset] = tempResult;\n"
316                         << "}\n";
317
318                 programCollection.glslSources.add("comp")
319                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
320         }
321         else
322         {
323                 {
324                         const string vertex =
325                                 "#version 450\n"
326                                 + extension +
327                                 "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
328                                 "{\n"
329                                 "  uint result[];\n"
330                                 "};\n"
331                                 "\n"
332                                 "void main (void)\n"
333                                 "{\n"
334                                 + bdy.str() +
335                                 "  result[gl_VertexIndex] = tempResult;\n"
336                                 "  float pixelSize = 2.0f/1024.0f;\n"
337                                 "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
338                                 "  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
339                                 "}\n";
340
341                         programCollection.glslSources.add("vert")
342                                 << glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
343                 }
344
345                 {
346                         const string tesc =
347                                 "#version 450\n"
348                                 + extension +
349                                 "layout(vertices=1) out;\n"
350                                 "layout(set = 0, binding = 1, std430) buffer Buffer1\n"
351                                 "{\n"
352                                 "  uint result[];\n"
353                                 "};\n"
354                                 "\n"
355                                 "void main (void)\n"
356                                 "{\n"
357                                 + bdy.str() +
358                                 "  result[gl_PrimitiveID] = 1;\n"
359                                 "  if (gl_InvocationID == 0)\n"
360                                 "  {\n"
361                                 "    gl_TessLevelOuter[0] = 1.0f;\n"
362                                 "    gl_TessLevelOuter[1] = 1.0f;\n"
363                                 "  }\n"
364                                 "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
365                                 "}\n";
366
367                         programCollection.glslSources.add("tesc")
368                                         << glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
369                 }
370
371                 {
372                         const string tese =
373                                 "#version 450\n"
374                                 + extension +
375                                 "layout(isolines) in;\n"
376                                 "layout(set = 0, binding = 2, std430) buffer Buffer1\n"
377                                 "{\n"
378                                 "  uint result[];\n"
379                                 "};\n"
380                                 "\n"
381                                 "void main (void)\n"
382                                 "{\n"
383                                 + bdy.str() +
384                                 "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = 1;\n"
385                                 "  float pixelSize = 2.0f/1024.0f;\n"
386                                 "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
387                                 "}\n";
388
389                         programCollection.glslSources.add("tese")
390                                         << glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
391                 }
392
393                 {
394                         const string geometry =
395                                 "#version 450\n"
396                                 + extension +
397                                 "layout(${TOPOLOGY}) in;\n"
398                                 "layout(points, max_vertices = 1) out;\n"
399                                 "layout(set = 0, binding = 3, std430) buffer Buffer1\n"
400                                 "{\n"
401                                 "  uint result[];\n"
402                                 "};\n"
403                                 "\n"
404                                 "void main (void)\n"
405                                 "{\n"
406                                 + bdy.str() +
407                                 "  result[gl_PrimitiveIDIn] = tempResult;\n"
408                                 "  gl_Position = gl_in[0].gl_Position;\n"
409                                 "  EmitVertex();\n"
410                                 "  EndPrimitive();\n"
411                                 "}\n";
412
413                         subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
414                                                                                                           programCollection.glslSources);
415                 }
416
417                 {
418                         const string fragment =
419                                 "#version 450\n"
420                                 + extension +
421                                 "layout(location = 0) out uint result;\n"
422                                 "void main (void)\n"
423                                 "{\n"
424                                 + bdy.str() +
425                                 "  result = tempResult;\n"
426                                 "}\n";
427
428                         programCollection.glslSources.add("fragment")
429                                 << glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
430                 }
431                 subgroups::addNoSubgroupShader(programCollection);
432         }
433 }
434
435 void supportedCheck (Context& context, CaseDefinition caseDef)
436 {
437         if (!subgroups::isSubgroupSupported(context))
438                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
439
440         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BALLOT_BIT))
441         {
442                 TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
443         }
444
445         if (OPTYPE_CLUSTERED == caseDef.opType)
446         {
447                 if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_CLUSTERED_BIT))
448                 {
449                         TCU_THROW(NotSupportedError, "Subgroup shape tests require that clustered operations are supported!");
450                 }
451         }
452
453         if (OPTYPE_QUAD == caseDef.opType)
454         {
455                 if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_QUAD_BIT))
456                 {
457                         TCU_THROW(NotSupportedError, "Subgroup shape tests require that quad operations are supported!");
458                 }
459         }
460 }
461
462 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
463 {
464         if (!subgroups::areSubgroupOperationsSupportedForStage(
465                                 context, caseDef.shaderStage))
466         {
467                 if (subgroups::areSubgroupOperationsRequiredForStage(
468                                         caseDef.shaderStage))
469                 {
470                         return tcu::TestStatus::fail(
471                                            "Shader stage " +
472                                            subgroups::getShaderStageName(caseDef.shaderStage) +
473                                            " is required to support subgroup operations!");
474                 }
475                 else
476                 {
477                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
478                 }
479         }
480
481         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
482                 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
483         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
484                 return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
485         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
486                 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
487         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
488                 return subgroups::makeTessellationEvaluationFrameBufferTest(context,  VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
489         else
490                 TCU_THROW(InternalError, "Unhandled shader stage");
491 }
492
493
494 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
495 {
496         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BASIC_BIT))
497         {
498                 return tcu::TestStatus::fail(
499                                    "Subgroup feature " +
500                                    subgroups::getShaderStageName(VK_SUBGROUP_FEATURE_BASIC_BIT) +
501                                    " is a required capability!");
502         }
503
504         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
505         {
506                 if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
507                 {
508                         return tcu::TestStatus::fail(
509                                            "Shader stage " +
510                                            subgroups::getShaderStageName(caseDef.shaderStage) +
511                                            " is required to support subgroup operations!");
512                 }
513                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkCompute);
514         }
515         else
516         {
517                 VkPhysicalDeviceSubgroupProperties subgroupProperties;
518                 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
519                 subgroupProperties.pNext = DE_NULL;
520
521                 VkPhysicalDeviceProperties2 properties;
522                 properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
523                 properties.pNext = &subgroupProperties;
524
525                 context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
526
527                 VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage  & subgroupProperties.supportedStages);
528
529                 if (VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
530                 {
531                         if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
532                                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
533                         else
534                                 stages = VK_SHADER_STAGE_FRAGMENT_BIT;
535                 }
536
537                 if ((VkShaderStageFlagBits)0u == stages)
538                         TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
539
540                 return subgroups::allStages(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages, stages);
541         }
542 }
543 }
544
545 namespace vkt
546 {
547 namespace subgroups
548 {
549 tcu::TestCaseGroup* createSubgroupsShapeTests(tcu::TestContext& testCtx)
550 {
551         de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
552                 testCtx, "graphics", "Subgroup shape category tests: graphics"));
553         de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
554                 testCtx, "compute", "Subgroup shape category tests: compute"));
555         de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
556                 testCtx, "framebuffer", "Subgroup shape category tests: framebuffer"));
557
558         const VkShaderStageFlags stages[] =
559         {
560                 VK_SHADER_STAGE_VERTEX_BIT,
561                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
562                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
563                 VK_SHADER_STAGE_GEOMETRY_BIT,
564         };
565
566         for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
567         {
568                 const std::string op = de::toLower(getOpTypeName(opTypeIndex));
569
570                 {
571                         const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT};
572                         addFunctionCaseWithPrograms(computeGroup.get(), op, "", supportedCheck, initPrograms, test, caseDef);
573
574                 }
575
576                 {
577                         const CaseDefinition caseDef =
578                         {
579                                 opTypeIndex,
580                                 VK_SHADER_STAGE_ALL_GRAPHICS
581                         };
582                         addFunctionCaseWithPrograms(graphicGroup.get(),
583                                                                         op, "",
584                                                                         supportedCheck, initPrograms, test, caseDef);
585                 }
586
587                 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
588                 {
589                         const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex]};
590                         addFunctionCaseWithPrograms(framebufferGroup.get(),op + "_" + getShaderStageName(caseDef.shaderStage), "",
591                                                                                 supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
592                 }
593         }
594
595         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
596                 testCtx, "shape", "Subgroup shape category tests"));
597
598         group->addChild(graphicGroup.release());
599         group->addChild(computeGroup.release());
600         group->addChild(framebufferGroup.release());
601
602         return group.release();
603 }
604
605 } // subgroups
606 } // vkt