Merge remote-tracking branch 'goog/upstream-vulkan-cts-next' into vulkan-cts-1.1...
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / subgroups / vktSubgroupsBallotBroadcastTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsBallotBroadcastTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
// Ballot broadcast operation exercised by a generated test case.
enum OpType
{
	OPTYPE_BROADCAST		= 0,	// subgroupBroadcast(value, id)
	OPTYPE_BROADCAST_FIRST	= 1,	// subgroupBroadcastFirst(value)
	OPTYPE_LAST						// Count of op types; not itself a valid op.
};
44
45 static bool checkVertexPipelineStages(std::vector<const void*> datas,
46                                                                           deUint32 width, deUint32)
47 {
48         const deUint32* data =
49                 reinterpret_cast<const deUint32*>(datas[0]);
50         for (deUint32 x = 0; x < width; ++x)
51         {
52                 deUint32 val = data[x];
53
54                 if (0x3 != val)
55                 {
56                         return false;
57                 }
58         }
59
60         return true;
61 }
62
63 static bool checkFragment(std::vector<const void*> datas,
64                                                   deUint32 width, deUint32 height, deUint32)
65 {
66         const deUint32* data =
67                 reinterpret_cast<const deUint32*>(datas[0]);
68         for (deUint32 x = 0; x < width; ++x)
69         {
70                 for (deUint32 y = 0; y < height; ++y)
71                 {
72                         deUint32 val = data[x * height + y];
73
74                         if (0x3 != val)
75                         {
76                                 return false;
77                         }
78                 }
79         }
80
81         return true;
82 }
83
84 static bool checkCompute(std::vector<const void*> datas,
85                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
86                                                  deUint32)
87 {
88         const deUint32* data =
89                 reinterpret_cast<const deUint32*>(datas[0]);
90
91         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
92         {
93                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
94                 {
95                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
96                         {
97                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
98                                 {
99                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
100                                         {
101                                                 for (deUint32 lZ = 0; lZ < localSize[2];
102                                                                 ++lZ)
103                                                 {
104                                                         const deUint32 globalInvocationX =
105                                                                 nX * localSize[0] + lX;
106                                                         const deUint32 globalInvocationY =
107                                                                 nY * localSize[1] + lY;
108                                                         const deUint32 globalInvocationZ =
109                                                                 nZ * localSize[2] + lZ;
110
111                                                         const deUint32 globalSizeX =
112                                                                 numWorkgroups[0] * localSize[0];
113                                                         const deUint32 globalSizeY =
114                                                                 numWorkgroups[1] * localSize[1];
115
116                                                         const deUint32 offset =
117                                                                 globalSizeX *
118                                                                 ((globalSizeY *
119                                                                   globalInvocationZ) +
120                                                                  globalInvocationY) +
121                                                                 globalInvocationX;
122
123                                                         if (0x3 != data[offset])
124                                                         {
125                                                                 return false;
126                                                         }
127                                                 }
128                                         }
129                                 }
130                         }
131                 }
132         }
133
134         return true;
135 }
136
137
138 std::string getOpTypeName(int opType)
139 {
140         switch (opType)
141         {
142                 default:
143                         DE_FATAL("Unsupported op type");
144                 case OPTYPE_BROADCAST:
145                         return "subgroupBroadcast";
146                 case OPTYPE_BROADCAST_FIRST:
147                         return "subgroupBroadcastFirst";
148         }
149 }
150
151
152 struct CaseDefinition
153 {
154         int                                     opType;
155         VkShaderStageFlags      shaderStage;
156         VkFormat                        format;
157         bool                            noSSBO;
158 };
159
160 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
161 {
162         std::ostringstream bdy;
163
164         bdy << "  uint tempResult = 0;\n";
165
166         if (OPTYPE_BROADCAST == caseDef.opType)
167         {
168                 bdy << "  tempResult = 0x3;\n";
169
170                 for (deUint32 i = 0; i < subgroups::maxSupportedSubgroupSize(); i++)
171                 {
172                         bdy     << "  {\n"
173                                 << "    const uint id = " << i << ";\n"
174                                 << "    " << subgroups::getFormatNameForGLSL(caseDef.format)
175                                 << " op = subgroupBroadcast(data1[gl_SubgroupInvocationID], id);\n"
176                                 << "    if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
177                                 << "    {\n"
178                                 << "      if (op != data1[id])\n"
179                                 << "      {\n"
180                                 << "        tempResult = 0;\n"
181                                 << "      }\n"
182                                 << "    }\n"
183                                 << "  }\n";
184                 }
185         }
186         else
187         {
188                 bdy     << "  uint firstActive = 0;\n"
189                         << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
190                         << "  {\n"
191                         << "    if (subgroupBallotBitExtract(mask, i))\n"
192                         << "    {\n"
193                         << "      firstActive = i;\n"
194                         << "      break;\n"
195                         << "    }\n"
196                         << "  }\n"
197                         << "  tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x1 : 0;\n"
198                         << "  // make the firstActive invocation inactive now\n"
199                         << "  if (firstActive == gl_SubgroupInvocationID)\n"
200                         << "  {\n"
201                         << "    for (uint i = 0; i < gl_SubgroupSize; i++)\n"
202                         << "    {\n"
203                         << "      if (subgroupBallotBitExtract(mask, i))\n"
204                         << "      {\n"
205                         << "        firstActive = i;\n"
206                         << "        break;\n"
207                         << "      }\n"
208                         << "    }\n"
209                         << "    tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x2 : 0;\n"
210                         << "  }\n"
211                         << "  else\n"
212                         << "  {\n"
213                         << "    // the firstActive invocation didn't partake in the second result so set it to true\n"
214                         << "    tempResult |= 0x2;\n"
215                         << "  }\n";
216         }
217
218         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
219         {
220                 std::ostringstream src;
221                 std::ostringstream      fragmentSrc;
222
223                 src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
224                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
225                         << "layout(location = 0) in highp vec4 in_position;\n"
226                         << "layout(location = 0) out float out_color;\n"
227                         << "layout(set = 0, binding = 0) uniform  Buffer1\n"
228                         << "{\n"
229                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" << subgroups::maxSupportedSubgroupSize() << "];\n"
230                         << "};\n"
231                         << "\n"
232                         << "void main (void)\n"
233                         << "{\n"
234                         << "  uvec4 mask = subgroupBallot(true);\n"
235                         << bdy.str()
236                         << "  out_color = float(tempResult);\n"
237                         << "  gl_Position = in_position;\n"
238                         << "  gl_PointSize = 1.0f;\n"
239                         << "}\n";
240
241                 programCollection.glslSources.add("vert")
242                                 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
243
244                 fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
245                         << "layout(location = 0) in float in_color;\n"
246                         << "layout(location = 0) out uint out_color;\n"
247                         << "void main()\n"
248                         <<"{\n"
249                         << "    out_color = uint(in_color);\n"
250                         << "}\n";
251                 programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
252         }
253         else
254         {
255                 DE_FATAL("Unsupported shader stage");
256         }
257 }
258
259 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
260 {
261         std::ostringstream bdy;
262
263         bdy << "  uint tempResult = 0;\n";
264
265         if (OPTYPE_BROADCAST == caseDef.opType)
266         {
267                 bdy << "  tempResult = 0x3;\n";
268
269                 for (deUint32 i = 0; i < subgroups::maxSupportedSubgroupSize(); i++)
270                 {
271                         bdy     << "  {\n"
272                                 << "    const uint id = " << i << ";\n"
273                                 << "    " << subgroups::getFormatNameForGLSL(caseDef.format)
274                                 << " op = subgroupBroadcast(data1[gl_SubgroupInvocationID], id);\n"
275                                 << "    if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
276                                 << "    {\n"
277                                 << "      if (op != data1[id])\n"
278                                 << "      {\n"
279                                 << "        tempResult = 0;\n"
280                                 << "      }\n"
281                                 << "    }\n"
282                                 << "  }\n";
283                 }
284         }
285         else
286         {
287                 bdy     << "  uint firstActive = 0;\n"
288                         << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
289                         << "  {\n"
290                         << "    if (subgroupBallotBitExtract(mask, i))\n"
291                         << "    {\n"
292                         << "      firstActive = i;\n"
293                         << "      break;\n"
294                         << "    }\n"
295                         << "  }\n"
296                         << "  tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x1 : 0;\n"
297                         << "  // make the firstActive invocation inactive now\n"
298                         << "  if (firstActive == gl_SubgroupInvocationID)\n"
299                         << "  {\n"
300                         << "    for (uint i = 0; i < gl_SubgroupSize; i++)\n"
301                         << "    {\n"
302                         << "      if (subgroupBallotBitExtract(mask, i))\n"
303                         << "      {\n"
304                         << "        firstActive = i;\n"
305                         << "        break;\n"
306                         << "      }\n"
307                         << "    }\n"
308                         << "    tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x2 : 0;\n"
309                         << "  }\n"
310                         << "  else\n"
311                         << "  {\n"
312                         << "    // the firstActive invocation didn't partake in the second result so set it to true\n"
313                         << "    tempResult |= 0x2;\n"
314                         << "  }\n";
315         }
316
317         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
318         {
319                 std::ostringstream src;
320
321                 src << "#version 450\n"
322                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
323                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
324                         "local_size_z_id = 2) in;\n"
325                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
326                         << "{\n"
327                         << "  uint result[];\n"
328                         << "};\n"
329                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
330                         << "{\n"
331                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
332                         << "};\n"
333                         << "\n"
334                         << "void main (void)\n"
335                         << "{\n"
336                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
337                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
338                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
339                         "gl_GlobalInvocationID.x;\n"
340                         << "  uvec4 mask = subgroupBallot(true);\n"
341                         << bdy.str()
342                         << "  result[offset] = tempResult;\n"
343                         << "}\n";
344
345                 programCollection.glslSources.add("comp")
346                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
347         }
348         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
349         {
350                 programCollection.glslSources.add("vert")
351                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
352
353                 std::ostringstream frag;
354
355                 frag << "#version 450\n"
356                          << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
357                          << "layout(location = 0) out uint result;\n"
358                          << "layout(set = 0, binding = 0, std430) readonly buffer Buffer1\n"
359                          << "{\n"
360                          << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
361                          << "};\n"
362                          << "void main (void)\n"
363                          << "{\n"
364                          << "  uvec4 mask = subgroupBallot(true);\n"
365                          << bdy.str()
366                          << "  result = tempResult;\n"
367                          << "}\n";
368
369                 programCollection.glslSources.add("frag")
370                                 << glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
371         }
372         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
373         {
374                 std::ostringstream src;
375
376                 src << "#version 450\n"
377                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
378                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
379                         << "{\n"
380                         << "  uint result[];\n"
381                         << "};\n"
382                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
383                         << "{\n"
384                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
385                         << "};\n"
386                         << "\n"
387                         << "void main (void)\n"
388                         << "{\n"
389                         << "  uvec4 mask = subgroupBallot(true);\n"
390                         << bdy.str()
391                         << "  result[gl_VertexIndex] = tempResult;\n"
392                         << "  gl_PointSize = 1.0f;\n"
393                         << "}\n";
394
395                 programCollection.glslSources.add("vert")
396                                 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
397         }
398         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
399         {
400                 programCollection.glslSources.add("vert")
401                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
402
403                 std::ostringstream src;
404
405                 src << "#version 450\n"
406                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
407                         << "layout(points) in;\n"
408                         << "layout(points, max_vertices = 1) out;\n"
409                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
410                         << "{\n"
411                         << "  uint result[];\n"
412                         << "};\n"
413                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
414                         << "{\n"
415                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
416                         << "};\n"
417                         << "\n"
418                         << "void main (void)\n"
419                         << "{\n"
420                         << "  uvec4 mask = subgroupBallot(true);\n"
421                         << bdy.str()
422                         << "  result[gl_PrimitiveIDIn] = tempResult;\n"
423                         << "}\n";
424
425                 programCollection.glslSources.add("geom")
426                                 << glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
427         }
428         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
429         {
430                 programCollection.glslSources.add("vert")
431                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
432
433                 programCollection.glslSources.add("tese")
434                                 << glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
435
436                 std::ostringstream src;
437
438                 src << "#version 450\n"
439                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
440                         << "layout(vertices=1) out;\n"
441                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
442                         << "{\n"
443                         << "  uint result[];\n"
444                         << "};\n"
445                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
446                         << "{\n"
447                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
448                         << "};\n"
449                         << "\n"
450                         << "void main (void)\n"
451                         << "{\n"
452                         << "  uvec4 mask = subgroupBallot(true);\n"
453                         << bdy.str()
454                         << "  result[gl_PrimitiveID] = tempResult;\n"
455                         << "}\n";
456
457                 programCollection.glslSources.add("tesc")
458                                 << glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
459         }
460         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
461         {
462                 programCollection.glslSources.add("vert")
463                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
464
465                 programCollection.glslSources.add("tesc")
466                                 << glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
467
468                 std::ostringstream src;
469
470                 src << "#version 450\n"
471                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
472                         << "layout(isolines) in;\n"
473                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
474                         << "{\n"
475                         << "  uint result[];\n"
476                         << "};\n"
477                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
478                         << "{\n"
479                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
480                         << "};\n"
481                         << "\n"
482                         << "void main (void)\n"
483                         << "{\n"
484                         << "  uvec4 mask = subgroupBallot(true);\n"
485                         << bdy.str()
486                         << "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
487                         << "}\n";
488
489                 programCollection.glslSources.add("tese")
490                                 << glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
491         }
492         else
493         {
494                 DE_FATAL("Unsupported shader stage");
495         }
496 }
497
498 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
499 {
500         if (!subgroups::isSubgroupSupported(context))
501                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
502
503         if (!subgroups::areSubgroupOperationsSupportedForStage(
504                                 context, caseDef.shaderStage))
505         {
506                 if (subgroups::areSubgroupOperationsRequiredForStage(
507                                         caseDef.shaderStage))
508                 {
509                         return tcu::TestStatus::fail(
510                                            "Shader stage " +
511                                            subgroups::getShaderStageName(caseDef.shaderStage) +
512                                            " is required to support subgroup operations!");
513                 }
514                 else
515                 {
516                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
517                 }
518         }
519
520         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BALLOT_BIT))
521         {
522                 TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
523         }
524
525         if (subgroups::isDoubleFormat(caseDef.format) &&
526                         !subgroups::isDoubleSupportedForDevice(context))
527         {
528                 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
529         }
530
531         //Tests which don't use the SSBO
532         if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
533         {
534                 subgroups::SSBOData inputData[1];
535                 inputData[0].format = caseDef.format;
536                 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
537                 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
538
539                 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 1, checkVertexPipelineStages);
540         }
541
542         if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
543                         (VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
544         {
545                 if (!subgroups::isVertexSSBOSupportedForDevice(context))
546                 {
547                         TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
548                 }
549         }
550
551         if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
552         {
553                 subgroups::SSBOData inputData[1];
554                 inputData[0].format = caseDef.format;
555                 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
556                 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
557
558                 return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT,
559                                                                                    inputData, 1, checkFragment);
560         }
561         else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
562         {
563                 subgroups::SSBOData inputData[1];
564                 inputData[0].format = caseDef.format;
565                 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
566                 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
567
568                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT,
569                                                                                   inputData, 1, checkCompute);
570         }
571         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
572         {
573                 subgroups::SSBOData inputData[1];
574                 inputData[0].format = caseDef.format;
575                 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
576                 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
577
578                 return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT,
579                                                                                  inputData, 1, checkVertexPipelineStages);
580         }
581         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
582         {
583                 subgroups::SSBOData inputData[1];
584                 inputData[0].format = caseDef.format;
585                 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
586                 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
587
588                 return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT,
589                                                                                    inputData, 1, checkVertexPipelineStages);
590         }
591         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
592         {
593                 subgroups::SSBOData inputData[1];
594                 inputData[0].format = caseDef.format;
595                 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
596                 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
597
598                 return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT,
599                                 inputData, 1, checkVertexPipelineStages);
600         }
601         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
602         {
603                 subgroups::SSBOData inputData[1];
604                 inputData[0].format = caseDef.format;
605                 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
606                 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
607
608                 return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT,
609                                 inputData, 1, checkVertexPipelineStages);
610         }
611         else
612         {
613                 TCU_THROW(InternalError, "Unhandled shader stage");
614         }
615 }
616 }
617
618 namespace vkt
619 {
620 namespace subgroups
621 {
622 tcu::TestCaseGroup* createSubgroupsBallotBroadcastTests(tcu::TestContext& testCtx)
623 {
624         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
625                         testCtx, "ballot_broadcast", "Subgroup ballot broadcast category tests"));
626
627         const VkShaderStageFlags stages[] =
628         {
629                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
630                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
631                 VK_SHADER_STAGE_GEOMETRY_BIT,
632                 VK_SHADER_STAGE_VERTEX_BIT,
633                 VK_SHADER_STAGE_FRAGMENT_BIT,
634                 VK_SHADER_STAGE_COMPUTE_BIT
635         };
636
637         const VkFormat formats[] =
638         {
639                 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
640                 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
641                 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
642                 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
643                 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
644                 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
645                 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
646                 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
647                 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
648         };
649
650         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
651         {
652                 const VkShaderStageFlags stage = stages[stageIndex];
653
654                 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
655                 {
656                         const VkFormat format = formats[formatIndex];
657
658                         for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
659                         {
660                                 CaseDefinition caseDef = {opTypeIndex, stage, format, false};
661
662                                 std::ostringstream name;
663
664                                 std::string op = getOpTypeName(opTypeIndex);
665
666                                 name << de::toLower(op) << "_" << subgroups::getFormatNameForGLSL(format)
667                                           << "_" << getShaderStageName(stage);
668
669                                 addFunctionCaseWithPrograms(group.get(), name.str(),
670                                                                                         "", initPrograms, test, caseDef);
671
672                                 if (VK_SHADER_STAGE_VERTEX_BIT == stage )
673                                 {
674                                         caseDef.noSSBO = true;
675                                         addFunctionCaseWithPrograms(group.get(), name.str()+"_framebuffer", "",
676                                                                 initFrameBufferPrograms, test, caseDef);
677                                 }
678                         }
679                 }
680         }
681
682         return group.release();
683 }
684
685 } // subgroups
686 } // vkt