341473913894ab7bb267c7db711fbd685849819a
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / subgroups / vktSubgroupsVoteTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsVoteTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
// Subgroup vote operation exercised by a test case.
enum OpType
{
	OPTYPE_ALL		= 0,	// subgroupAll()
	OPTYPE_ANY		= 1,	// subgroupAny()
	OPTYPE_ALLEQUAL	= 2,	// subgroupAllEqual()
	OPTYPE_LAST		= 3		// one past the last valid op type; not an operation itself
};
45
46 static bool checkVertexPipelineStages(std::vector<const void*> datas,
47                                                                           deUint32 width, deUint32)
48 {
49         const deUint32* data =
50                 reinterpret_cast<const deUint32*>(datas[0]);
51         for (deUint32 x = 0; x < width; ++x)
52         {
53                 deUint32 val = data[x];
54
55                 if (0x7 != val)
56                 {
57                         return false;
58                 }
59         }
60
61         return true;
62 }
63
64 static bool checkVertexPipelineStagesNoSSBO(std::vector<const void*> datas,
65                                                                           deUint32 width, deUint32)
66 {
67         const float* data =
68                 reinterpret_cast<const float*>(datas[0]);
69         for (deUint32 x = 0; x < width; ++x)
70         {
71                 deUint32 val = static_cast<deUint32>(data[x]);
72
73                 if (0x7 != val)
74                 {
75                         return false;
76                 }
77         }
78
79         return true;
80 }
81
82 static bool checkFragment(std::vector<const void*> datas, deUint32 width,
83                                                   deUint32 height, deUint32)
84 {
85         const deUint32* data =
86                 reinterpret_cast<const deUint32*>(datas[0]);
87         for (deUint32 x = 0; x < width; ++x)
88         {
89                 for (deUint32 y = 0; y < height; ++y)
90                 {
91                         deUint32 val = data[x * height + y];
92
93                         if (0x7 != val)
94                         {
95                                 return false;
96                         }
97                 }
98         }
99
100         return true;
101 }
102
103 static bool checkCompute(std::vector<const void*> datas,
104                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
105                                                  deUint32)
106 {
107         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
108
109         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
110         {
111                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
112                 {
113                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
114                         {
115                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
116                                 {
117                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
118                                         {
119                                                 for (deUint32 lZ = 0; lZ < localSize[2];
120                                                                 ++lZ)
121                                                 {
122                                                         const deUint32 globalInvocationX =
123                                                                 nX * localSize[0] + lX;
124                                                         const deUint32 globalInvocationY =
125                                                                 nY * localSize[1] + lY;
126                                                         const deUint32 globalInvocationZ =
127                                                                 nZ * localSize[2] + lZ;
128
129                                                         const deUint32 globalSizeX =
130                                                                 numWorkgroups[0] * localSize[0];
131                                                         const deUint32 globalSizeY =
132                                                                 numWorkgroups[1] * localSize[1];
133
134                                                         const deUint32 offset =
135                                                                 globalSizeX *
136                                                                 ((globalSizeY *
137                                                                   globalInvocationZ) +
138                                                                  globalInvocationY) +
139                                                                 globalInvocationX;
140
141                                                         // The data should look (in binary) 0b111
142                                                         if (0x7 != data[offset])
143                                                         {
144                                                                 return false;
145                                                         }
146                                                 }
147                                         }
148                                 }
149                         }
150                 }
151         }
152
153         return true;
154 }
155
156 static bool checkComputeAllEqual(std::vector<const void*> datas,
157                                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
158                                                                  deUint32)
159 {
160         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
161
162         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
163         {
164                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
165                 {
166                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
167                         {
168                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
169                                 {
170                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
171                                         {
172                                                 for (deUint32 lZ = 0; lZ < localSize[2];
173                                                                 ++lZ)
174                                                 {
175                                                         const deUint32 globalInvocationX =
176                                                                 nX * localSize[0] + lX;
177                                                         const deUint32 globalInvocationY =
178                                                                 nY * localSize[1] + lY;
179                                                         const deUint32 globalInvocationZ =
180                                                                 nZ * localSize[2] + lZ;
181
182                                                         const deUint32 globalSizeX =
183                                                                 numWorkgroups[0] * localSize[0];
184                                                         const deUint32 globalSizeY =
185                                                                 numWorkgroups[1] * localSize[1];
186
187                                                         const deUint32 offset =
188                                                                 globalSizeX *
189                                                                 ((globalSizeY *
190                                                                   globalInvocationZ) +
191                                                                  globalInvocationY) +
192                                                                 globalInvocationX;
193
194                                                         // The data should look (in binary) 0b111
195                                                         if (0x7 != data[offset])
196                                                         {
197                                                                 return false;
198                                                         }
199                                                 }
200                                         }
201                                 }
202                         }
203                 }
204         }
205
206         return true;
207 }
208
209 std::string getOpTypeName(int opType)
210 {
211         switch (opType)
212         {
213                 default:
214                         DE_FATAL("Unsupported op type");
215                 case OPTYPE_ALL:
216                         return "subgroupAll";
217                 case OPTYPE_ANY:
218                         return "subgroupAny";
219                 case OPTYPE_ALLEQUAL:
220                         return "subgroupAllEqual";
221         }
222 }
223
224 struct CaseDefinition
225 {
226         int                                     opType;
227         VkShaderStageFlags      shaderStage;
228         VkFormat                        format;
229         bool                            noSSBO;
230 };
231
232 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
233 {
234         std::ostringstream      vertexSrc;
235         std::ostringstream      fragmentSrc;
236
237         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
238         {
239                 vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
240                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
241                         << "layout(location = 0) out vec4 out_color;\n"
242                         << "layout(location = 0) in highp vec4 in_position;\n"
243                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
244                         << "{\n"
245                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
246                         << "};\n"
247                         << "\n"
248                         << "void main (void)\n"
249                         << "{\n"
250                         << "  uint result;\n";
251                 if (OPTYPE_ALL == caseDef.opType)
252                 {
253                         vertexSrc << " result = " << getOpTypeName(caseDef.opType)
254                                 << "(true) ? 0x1 : 0;\n"
255                                 << "  result |= " << getOpTypeName(caseDef.opType)
256                                 << "(false) ? 0 : 0x2;\n"
257                                 << "  result |= 0x4;\n"
258                                 << "  out_color.r = float(result);\n";
259                 }
260                 else if (OPTYPE_ANY == caseDef.opType)
261                 {
262                         vertexSrc << "  result = " << getOpTypeName(caseDef.opType)
263                                 << "(true) ? 0x1 : 0;\n"
264                                 << "  result |= " << getOpTypeName(caseDef.opType)
265                                 << "(false) ? 0 : 0x2;\n"
266                                 << "  result |= 0x4;\n"
267                                 << "out_color.r = float(result);\n";
268                 }
269                 else if (OPTYPE_ALLEQUAL == caseDef.opType)
270                 {
271                         vertexSrc << "  result = " << getOpTypeName(caseDef.opType) << "("
272                                 << subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
273                                 << "  result |= " << getOpTypeName(caseDef.opType)
274                                 << "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
275                                 << "  if (subgroupElect()) result |= 0x2;\n"
276                                 << "  result |= " << getOpTypeName(caseDef.opType)
277                                 << "(data[0]) ? 0x4 : 0;\n"
278                                 << "  out_color.x = float(result);\n";
279                 }
280
281                 vertexSrc << "  gl_Position = in_position;\n"
282                         << "  gl_PointSize = 1.0f;\n"
283                         << "}\n";
284
285                 programCollection.glslSources.add("vert") << glu::VertexSource(vertexSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
286
287                 fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
288                         << "layout(location = 0) in vec4 in_color;\n"
289                         << "layout(location = 0) out vec4 out_color;\n"
290                         << "void main()\n"
291                         <<"{\n"
292                         << "    out_color = in_color;\n"
293                         << "}\n";
294                 programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
295         }
296         else
297         {
298                 DE_FATAL("Unsupported shader stage");
299         }
300 }
301
302 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
303 {
304         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
305         {
306                 std::ostringstream src;
307
308                 src << "#version 450\n"
309                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
310                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
311                         "local_size_z_id = 2) in;\n"
312                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
313                         << "{\n"
314                         << "  uint result[];\n"
315                         << "};\n"
316                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
317                         << "{\n"
318                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
319                         << "};\n"
320                         << "\n"
321                         << "void main (void)\n"
322                         << "{\n"
323                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
324                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
325                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
326                         "gl_GlobalInvocationID.x;\n";
327                 if (OPTYPE_ALL == caseDef.opType)
328                 {
329                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
330                                 << "(true) ? 0x1 : 0;\n"
331                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
332                                 << "(false) ? 0 : 0x2;\n"
333                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
334                                 << "(data[gl_SubgroupInvocationID] > 0) ? 0x4 : 0;\n";
335                 }
336                 else if (OPTYPE_ANY == caseDef.opType)
337                 {
338                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
339                                 << "(true) ? 0x1 : 0;\n"
340                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
341                                 << "(false) ? 0 : 0x2;\n"
342                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
343                                 << "(data[gl_SubgroupInvocationID] == data[0]) ? 0x4 : 0;\n";
344                 }
345                 else if (OPTYPE_ALLEQUAL == caseDef.opType)
346                 {
347                         src << "  result[offset] = " << getOpTypeName(caseDef.opType) << "("
348                                 << subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
349                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
350                                 << "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
351                                 << "  if (subgroupElect()) result[offset] |= 0x2;\n"
352                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
353                                 << "(data[0]) ? 0x4 : 0;\n";
354                 }
355
356                 src << "}\n";
357
358                 programCollection.glslSources.add("comp")
359                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
360         }
361         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
362         {
363                 programCollection.glslSources.add("vert")
364                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
365
366                 std::ostringstream frag;
367
368                 frag << "#version 450\n"
369                          << "#extension GL_KHR_shader_subgroup_vote: enable\n"
370                          << "layout(location = 0) out uint result;\n"
371                          << "layout(set = 0, binding = 0, std430) readonly buffer Buffer2\n"
372                          << "{\n"
373                          << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
374                          << "};\n"
375                          << "void main (void)\n"
376                          << "{\n";
377                 if (OPTYPE_ALL == caseDef.opType)
378                 {
379                         frag << "  result = " << getOpTypeName(caseDef.opType)
380                                  << "(true) ? 0x1 : 0;\n"
381                                  << "  result |= " << getOpTypeName(caseDef.opType)
382                                  << "(false) ? 0 : 0x2;\n"
383                                  << "  result |= " << getOpTypeName(caseDef.opType)
384                                  << "(data[gl_SubgroupInvocationID] > 0) ? 0x4 : 0;\n";
385                 }
386                 else if (OPTYPE_ANY == caseDef.opType)
387                 {
388                         frag << "  result = " << getOpTypeName(caseDef.opType)
389                                  << "(true) ? 0x1 : 0;\n"
390                                  << "  result |= " << getOpTypeName(caseDef.opType)
391                                  << "(false) ? 0 : 0x2;\n"
392                                  << "  result |= " << getOpTypeName(caseDef.opType)
393                                  << "(data[gl_SubgroupInvocationID] == data[0]) ? 0x4 : 0;\n";
394                 }
395                 else if (OPTYPE_ALLEQUAL == caseDef.opType)
396                 {
397                         frag << "  result = " << getOpTypeName(caseDef.opType) << "("
398                                  << subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
399                                  << "  result |= " << getOpTypeName(caseDef.opType)
400                                  << "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
401                                  << "  if (subgroupElect()) result |= 0x2;\n"
402                                  << "  result |= " << getOpTypeName(caseDef.opType)
403                                  << "(data[0]) ? 0x4 : 0;\n";
404                 }
405                 frag << "}\n";
406
407                 programCollection.glslSources.add("frag")
408                                 << glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
409         }
410         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
411         {
412                 std::ostringstream src;
413
414                 src << "#version 450\n"
415                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
416                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
417                         << "{\n"
418                         << "  uint result[];\n"
419                         << "};\n"
420                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
421                         << "{\n"
422                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
423                         << "};\n"
424                         << "\n"
425                         << "void main (void)\n"
426                         << "{\n"
427                         << "  highp uint offset = gl_VertexIndex;\n";
428                 if (OPTYPE_ALL == caseDef.opType)
429                 {
430                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
431                                 << "(true) ? 0x1 : 0;\n"
432                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
433                                 << "(false) ? 0 : 0x2;\n"
434                                 << "  result[offset] |= 0x4;\n";
435                 }
436                 else if (OPTYPE_ANY == caseDef.opType)
437                 {
438                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
439                                 << "(true) ? 0x1 : 0;\n"
440                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
441                                 << "(false) ? 0 : 0x2;\n"
442                                 << "  result[offset] |= 0x4;\n";
443                 }
444                 else if (OPTYPE_ALLEQUAL == caseDef.opType)
445                 {
446                         src << "  result[offset] = " << getOpTypeName(caseDef.opType) << "("
447                                 << subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
448                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
449                                 << "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
450                                 << "  if (subgroupElect()) result[offset] |= 0x2;\n"
451                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
452                                 << "(data[0]) ? 0x4 : 0;\n";
453                 }
454
455                 src << "}\n";
456
457                 programCollection.glslSources.add("vert")
458                                 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
459         }
460         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
461         {
462                 programCollection.glslSources.add("vert")
463                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
464
465                 std::ostringstream src;
466
467                 src << "#version 450\n"
468                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
469                         << "layout(points) in;\n"
470                         << "layout(points, max_vertices = 1) out;\n"
471                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
472                         << "{\n"
473                         << "  uint result[];\n"
474                         << "};\n"
475                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
476                         << "{\n"
477                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
478                         << "};\n"
479                         << "\n"
480                         << "void main (void)\n"
481                         << "{\n"
482                         << "  highp uint offset = gl_PrimitiveIDIn;\n";
483                 if (OPTYPE_ALL == caseDef.opType)
484                 {
485                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
486                                 << "(true) ? 0x1 : 0;\n"
487                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
488                                 << "(false) ? 0 : 0x2;\n"
489                                 << "  result[offset] |= 0x4;\n";
490                 }
491                 else if (OPTYPE_ANY == caseDef.opType)
492                 {
493                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
494                                 << "(true) ? 0x1 : 0;\n"
495                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
496                                 << "(false) ? 0 : 0x2;\n"
497                                 << "  result[offset] |= 0x4;\n";
498                 }
499                 else if (OPTYPE_ALLEQUAL == caseDef.opType)
500                 {
501                         src << "  result[offset] = " << getOpTypeName(caseDef.opType) << "("
502                                 << subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
503                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
504                                 << "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
505                                 << "  if (subgroupElect()) result[offset] |= 0x2;\n"
506                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
507                                 << "(data[0]) ? 0x4 : 0;\n";
508                 }
509
510                 src << "}\n";
511
512                 programCollection.glslSources.add("geom")
513                                 << glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
514         }
515         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
516         {
517                 programCollection.glslSources.add("vert")
518                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
519
520                 programCollection.glslSources.add("tese")
521                                 << glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
522
523                 std::ostringstream src;
524
525                 src << "#version 450\n"
526                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
527                         << "layout(vertices=1) out;\n"
528                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
529                         << "{\n"
530                         << "  uint result[];\n"
531                         << "};\n"
532                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
533                         << "{\n"
534                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
535                         << "};\n"
536                         << "\n"
537                         << "void main (void)\n"
538                         << "{\n"
539                         << "  highp uint offset = gl_PrimitiveID;\n";
540                 if (OPTYPE_ALL == caseDef.opType)
541                 {
542                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
543                                 << "(true) ? 0x1 : 0;\n"
544                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
545                                 << "(false) ? 0 : 0x2;\n"
546                                 << "  result[offset] |= 0x4;\n";
547                 }
548                 else if (OPTYPE_ANY == caseDef.opType)
549                 {
550                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
551                                 << "(true) ? 0x1 : 0;\n"
552                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
553                                 << "(false) ? 0 : 0x2;\n"
554                                 << "  result[offset] |= 0x4;\n";
555                 }
556                 else if (OPTYPE_ALLEQUAL == caseDef.opType)
557                 {
558                         src << "  result[offset] = " << getOpTypeName(caseDef.opType) << "("
559                                 << subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
560                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
561                                 << "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
562                                 << "  if (subgroupElect()) result[offset] |= 0x2;\n"
563                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
564                                 << "(data[0]) ? 0x4 : 0;\n";
565                 }
566
567                 src << "}\n";
568
569                 programCollection.glslSources.add("tesc")
570                                 << glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
571         }
572         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
573         {
574                 programCollection.glslSources.add("vert")
575                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
576
577                 programCollection.glslSources.add("tesc")
578                                 << glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
579
580                 std::ostringstream src;
581
582                 src << "#version 450\n"
583                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
584                         << "layout(isolines) in;\n"
585                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
586                         << "{\n"
587                         << "  uint result[];\n"
588                         << "};\n"
589                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
590                         << "{\n"
591                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
592                         << "};\n"
593                         << "\n"
594                         << "void main (void)\n"
595                         << "{\n"
596                         << "  highp uint offset = gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5);\n";
597                 if (OPTYPE_ALL == caseDef.opType)
598                 {
599                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
600                                 << "(true) ? 0x1 : 0;\n"
601                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
602                                 << "(false) ? 0 : 0x2;\n"
603                                 << "  result[offset] |= 0x4;\n";
604                 }
605                 else if (OPTYPE_ANY == caseDef.opType)
606                 {
607                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
608                                 << "(true) ? 0x1 : 0;\n"
609                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
610                                 << "(false) ? 0 : 0x2;\n"
611                                 << "  result[offset] |= 0x4;\n";
612                 }
613                 else if (OPTYPE_ALLEQUAL == caseDef.opType)
614                 {
615                         src << "  result[offset] = " << getOpTypeName(caseDef.opType) << "("
616                                 << subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
617                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
618                                 << "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
619                                 << "  if (subgroupElect()) result[offset] |= 0x2;\n"
620                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
621                                 << "(data[0]) ? 0x4 : 0;\n";
622                 }
623
624                 src << "}\n";
625
626                 programCollection.glslSources.add("tese")
627                                 << glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
628         }
629         else
630         {
631                 DE_FATAL("Unsupported shader stage");
632         }
633 }
634
635 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
636 {
637         if (!subgroups::isSubgroupSupported(context))
638                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
639
640         if (!subgroups::areSubgroupOperationsSupportedForStage(
641                                 context, caseDef.shaderStage))
642         {
643                 if (subgroups::areSubgroupOperationsRequiredForStage(
644                                         caseDef.shaderStage))
645                 {
646                         return tcu::TestStatus::fail(
647                                            "Shader stage " +
648                                            subgroups::getShaderStageName(caseDef.shaderStage) +
649                                            " is required to support subgroup operations!");
650                 }
651                 else
652                 {
653                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
654                 }
655         }
656
657         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_VOTE_BIT))
658         {
659                 TCU_THROW(NotSupportedError, "Device does not support subgroup vote operations");
660         }
661
662         if (subgroups::isDoubleFormat(caseDef.format) &&
663                         !subgroups::isDoubleSupportedForDevice(context))
664         {
665                 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
666         }
667
668         //Tests which don't use the SSBO
669         if (caseDef.noSSBO)
670         {
671                 if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
672                 {
673                         subgroups::SSBOData inputData;
674                         inputData.format = caseDef.format;
675                         inputData.numElements = subgroups::maxSupportedSubgroupSize();
676                         inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
677
678                         return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_SFLOAT, &inputData,
679                                                                                          1, checkVertexPipelineStagesNoSSBO);
680                 }
681         }
682
683         if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
684                         (VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
685         {
686                 if (!subgroups::isVertexSSBOSupportedForDevice(context))
687                 {
688                         TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
689                 }
690         }
691
692         if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
693         {
694                 subgroups::SSBOData inputData;
695                 inputData.format = caseDef.format;
696                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
697                 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
698
699                 return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT,
700                                                                                    &inputData, 1, checkFragment);
701         }
702         else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
703         {
704                 subgroups::SSBOData inputData;
705                 inputData.format = caseDef.format;
706                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
707                 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
708
709                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData,
710                                                                                   1, (OPTYPE_ALLEQUAL == caseDef.opType) ? checkComputeAllEqual
711                                                                                   : checkCompute);
712         }
713         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
714         {
715                 subgroups::SSBOData inputData;
716                 inputData.format = caseDef.format;
717                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
718                 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
719
720                 return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT, &inputData,
721                                                                                  1, checkVertexPipelineStages);
722         }
723         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
724         {
725                 subgroups::SSBOData inputData;
726                 inputData.format = caseDef.format;
727                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
728                 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
729
730                 return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT, &inputData,
731                                                                                    1, checkVertexPipelineStages);
732         }
733         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
734         {
735                 subgroups::SSBOData inputData;
736                 inputData.format = caseDef.format;
737                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
738                 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
739
740                 return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT, &inputData,
741                                 1, checkVertexPipelineStages);
742         }
743         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
744         {
745                 subgroups::SSBOData inputData;
746                 inputData.format = caseDef.format;
747                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
748                 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
749
750                 return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT, &inputData,
751                                 1, checkVertexPipelineStages);
752         }
753         else
754         {
755                 TCU_THROW(InternalError, "Unhandled shader stage");
756         }
757 }
758 }
759
760 namespace vkt
761 {
762 namespace subgroups
763 {
764 tcu::TestCaseGroup* createSubgroupsVoteTests(tcu::TestContext& testCtx)
765 {
766         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
767                         testCtx, "vote", "Subgroup vote category tests"));
768
769         const VkShaderStageFlags stages[] =
770         {
771                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
772                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
773                 VK_SHADER_STAGE_GEOMETRY_BIT,
774                 VK_SHADER_STAGE_VERTEX_BIT,
775                 VK_SHADER_STAGE_FRAGMENT_BIT,
776                 VK_SHADER_STAGE_COMPUTE_BIT
777         };
778
779         const VkFormat formats[] =
780         {
781                 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
782                 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
783                 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
784                 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
785                 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
786                 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
787                 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
788                 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
789                 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
790         };
791
792         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
793         {
794                 const VkShaderStageFlags stage = stages[stageIndex];
795
796                 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
797                 {
798                         const VkFormat format = formats[formatIndex];
799
800                         for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
801                         {
802                                 // Skip the typed tests for all but subgroupAllEqual()
803                                 if ((VK_FORMAT_R32_UINT != format) && (OPTYPE_ALLEQUAL != opTypeIndex))
804                                 {
805                                         continue;
806                                 }
807
808                                 CaseDefinition caseDef = {opTypeIndex, stage, format, false};
809
810                                 std::string op = getOpTypeName(opTypeIndex);
811
812                                 addFunctionCaseWithPrograms(group.get(),
813                                                                                         de::toLower(op) + "_" +
814                                                                                         subgroups::getFormatNameForGLSL(format)
815                                                                                         + "_" + getShaderStageName(stage),
816                                                                                         "", initPrograms, test, caseDef);
817
818                                 if (VK_SHADER_STAGE_VERTEX_BIT == stage )
819                                 {
820                                         caseDef.noSSBO = true;
821                                         addFunctionCaseWithPrograms(group.get(),
822                                                                 de::toLower(op) + "_" +
823                                                                 subgroups::getFormatNameForGLSL(format)
824                                                                 + "_" + getShaderStageName(stage)+"_framebuffer", "",
825                                                                 initFrameBufferPrograms, test, caseDef);
826                                 }
827
828                         }
829                 }
830         }
831
832         return group.release();
833 }
834
835 } // subgroups
836 } // vkt