d66cfec78288238a63365cccc51ad0fcf2222264
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / subgroups / vktSubgroupsShuffleTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsShuffleTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
// Subgroup shuffle built-in exercised by a test case.
// The enumerators double as indices into per-op lookup tables, so they must
// stay contiguous and start at zero; OPTYPE_LAST is the number of real ops.
enum OpType
{
	OPTYPE_SHUFFLE		= 0,	// subgroupShuffle
	OPTYPE_SHUFFLE_XOR	= 1,	// subgroupShuffleXor
	OPTYPE_SHUFFLE_UP	= 2,	// subgroupShuffleUp
	OPTYPE_SHUFFLE_DOWN	= 3,	// subgroupShuffleDown
	OPTYPE_LAST			= 4		// count sentinel, not a valid op
};
46
47 static bool checkVertexPipelineStages(std::vector<const void*> datas,
48                                                                           deUint32 width, deUint32)
49 {
50         const deUint32* data =
51                 reinterpret_cast<const deUint32*>(datas[0]);
52         for (deUint32 x = 0; x < width; ++x)
53         {
54                 deUint32 val = data[x];
55
56                 if (0x1 != val)
57                 {
58                         return false;
59                 }
60         }
61
62         return true;
63 }
64
65 static bool checkFragment(std::vector<const void*> datas,
66                                                   deUint32 width, deUint32 height, deUint32)
67 {
68         const deUint32* data =
69                 reinterpret_cast<const deUint32*>(datas[0]);
70         for (deUint32 x = 0; x < width; ++x)
71         {
72                 for (deUint32 y = 0; y < height; ++y)
73                 {
74                         deUint32 val = data[x * height + y];
75
76                         if (0x1 != val)
77                         {
78                                 return false;
79                         }
80                 }
81         }
82
83         return true;
84 }
85
86 static bool checkCompute(std::vector<const void*> datas,
87                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
88                                                  deUint32)
89 {
90         const deUint32* data =
91                 reinterpret_cast<const deUint32*>(datas[0]);
92
93         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
94         {
95                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
96                 {
97                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
98                         {
99                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
100                                 {
101                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
102                                         {
103                                                 for (deUint32 lZ = 0; lZ < localSize[2];
104                                                                 ++lZ)
105                                                 {
106                                                         const deUint32 globalInvocationX =
107                                                                 nX * localSize[0] + lX;
108                                                         const deUint32 globalInvocationY =
109                                                                 nY * localSize[1] + lY;
110                                                         const deUint32 globalInvocationZ =
111                                                                 nZ * localSize[2] + lZ;
112
113                                                         const deUint32 globalSizeX =
114                                                                 numWorkgroups[0] * localSize[0];
115                                                         const deUint32 globalSizeY =
116                                                                 numWorkgroups[1] * localSize[1];
117
118                                                         const deUint32 offset =
119                                                                 globalSizeX *
120                                                                 ((globalSizeY *
121                                                                   globalInvocationZ) +
122                                                                  globalInvocationY) +
123                                                                 globalInvocationX;
124
125                                                         if (0x1 != data[offset])
126                                                         {
127                                                                 return false;
128                                                         }
129                                                 }
130                                         }
131                                 }
132                         }
133                 }
134         }
135
136         return true;
137 }
138
139 std::string getOpTypeName(int opType)
140 {
141         switch (opType)
142         {
143                 default:
144                         DE_FATAL("Unsupported op type");
145                 case OPTYPE_SHUFFLE:
146                         return "subgroupShuffle";
147                 case OPTYPE_SHUFFLE_XOR:
148                         return "subgroupShuffleXor";
149                 case OPTYPE_SHUFFLE_UP:
150                         return "subgroupShuffleUp";
151                 case OPTYPE_SHUFFLE_DOWN:
152                         return "subgroupShuffleDown";
153         }
154 }
155
156 struct CaseDefinition
157 {
158         int                                     opType;
159         VkShaderStageFlags      shaderStage;
160         VkFormat                        format;
161         bool                            noSSBO;
162 };
163
164 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
165 {
166         std::string idTable[OPTYPE_LAST];
167         idTable[OPTYPE_SHUFFLE]                 = "data2[gl_SubgroupInvocationID]";
168         idTable[OPTYPE_SHUFFLE_XOR]             = "gl_SubgroupInvocationID ^ data2[gl_SubgroupInvocationID]";
169         idTable[OPTYPE_SHUFFLE_UP]              = "gl_SubgroupInvocationID - data2[gl_SubgroupInvocationID]";
170         idTable[OPTYPE_SHUFFLE_DOWN]    = "gl_SubgroupInvocationID + data2[gl_SubgroupInvocationID]";
171
172         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
173         {
174                 std::ostringstream      src;
175                 std::ostringstream      fragmentSrc;
176
177                 src << "#version 450\n"
178                         << "layout(location = 0) in highp vec4 in_position;\n"
179                         << "layout(location = 0) out float result;\n";
180
181                 switch (caseDef.opType)
182                 {
183                         case OPTYPE_SHUFFLE:
184                         case OPTYPE_SHUFFLE_XOR:
185                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
186                                 break;
187                         default:
188                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
189                                 break;
190                 }
191
192                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
193                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
194                         << "{\n"
195                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" << subgroups::maxSupportedSubgroupSize() << "];\n"
196                         << "};\n"
197                         << "layout(set = 0, binding = 1) uniform Buffer2\n"
198                         << "{\n"
199                         << "  uint data2[" << subgroups::maxSupportedSubgroupSize() << "];\n"
200                         << "};\n"
201                         << "\n"
202                         << "void main (void)\n"
203                         << "{\n"
204                         << "  uvec4 mask = subgroupBallot(true);\n"
205                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
206                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
207                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
208                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
209                         << "  {\n"
210                         << "    result = (op == data1[id]) ? 1.0f : 0.0f;\n"
211                         << "  }\n"
212                         << "  else\n"
213                         << "  {\n"
214                         << "    result = 1.0f; // Invocation we read from was inactive, so we can't verify results!\n"
215                         << "  }\n"
216                         << "  gl_Position = in_position;\n"
217                         << "  gl_PointSize = 1.0f;\n"
218                         << "}\n";
219
220                 programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
221
222                 fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
223                         << "layout(location = 0) in float result;\n"
224                         << "layout(location = 0) out uint out_color;\n"
225                         << "void main()\n"
226                         <<"{\n"
227                         << "    out_color = uint(result);\n"
228                         << "}\n";
229                 programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
230         }
231         else
232         {
233                 DE_FATAL("Unsupported shader stage");
234         }
235 }
236
237 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
238 {
239         std::string idTable[OPTYPE_LAST];
240         idTable[OPTYPE_SHUFFLE] = "data2[gl_SubgroupInvocationID]";
241         idTable[OPTYPE_SHUFFLE_XOR] = "gl_SubgroupInvocationID ^ data2[gl_SubgroupInvocationID]";
242         idTable[OPTYPE_SHUFFLE_UP] = "gl_SubgroupInvocationID - data2[gl_SubgroupInvocationID]";
243         idTable[OPTYPE_SHUFFLE_DOWN] = "gl_SubgroupInvocationID + data2[gl_SubgroupInvocationID]";
244
245         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
246         {
247                 std::ostringstream src;
248
249                 src << "#version 450\n";
250
251                 switch (caseDef.opType)
252                 {
253                         case OPTYPE_SHUFFLE:
254                         case OPTYPE_SHUFFLE_XOR:
255                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
256                                 break;
257                         default:
258                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
259                                 break;
260                 }
261
262                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
263                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
264                         "local_size_z_id = 2) in;\n"
265                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
266                         << "{\n"
267                         << "  uint result[];\n"
268                         << "};\n"
269                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
270                         << "{\n"
271                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
272                         << "};\n"
273                         << "layout(set = 0, binding = 2, std430) buffer Buffer3\n"
274                         << "{\n"
275                         << "  uint data2[];\n"
276                         << "};\n"
277                         << "\n"
278                         << "void main (void)\n"
279                         << "{\n"
280                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
281                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
282                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
283                         "gl_GlobalInvocationID.x;\n"
284                         << "  uvec4 mask = subgroupBallot(true);\n"
285                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
286                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
287                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
288                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
289                         << "  {\n"
290                         << "    result[offset] = (op == data1[id]) ? 1 : 0;\n"
291                         << "  }\n"
292                         << "  else\n"
293                         << "  {\n"
294                         << "    result[offset] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
295                         << "  }\n"
296                         << "}\n";
297
298                 programCollection.glslSources.add("comp")
299                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
300         }
301         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
302         {
303                 programCollection.glslSources.add("vert")
304                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
305
306                 std::ostringstream frag;
307
308                 frag << "#version 450\n";
309
310                 switch (caseDef.opType)
311                 {
312                         case OPTYPE_SHUFFLE:
313                         case OPTYPE_SHUFFLE_XOR:
314                                 frag << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
315                                 break;
316                         default:
317                                 frag << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
318                                 break;
319                 }
320
321                 frag << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
322                          << "layout(location = 0) out uint result;\n"
323                          << "layout(set = 0, binding = 0, std430) readonly buffer Buffer1\n"
324                          << "{\n"
325                          << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
326                          << "};\n"
327                          << "layout(set = 0, binding = 1, std430) readonly buffer Buffer2\n"
328                          << "{\n"
329                          << "  uint data2[];\n"
330                          << "};\n"
331                          << "void main (void)\n"
332                          << "{\n"
333                          << "  uvec4 mask = subgroupBallot(true);\n"
334                          << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
335                          << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
336                          << "  uint id = " << idTable[caseDef.opType] << ";\n"
337                          << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
338                          << "  {\n"
339                          << "    result = (op == data1[id]) ? 1 : 0;\n"
340                          << "  }\n"
341                          << "  else\n"
342                          << "  {\n"
343                          << "    result = 1; // Invocation we read from was inactive, so we can't verify results!\n"
344                          << "  }\n"
345                          << "}\n";
346
347                 programCollection.glslSources.add("frag")
348                                 << glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
349         }
350         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
351         {
352                 std::ostringstream src;
353
354                 src << "#version 450\n";
355
356                 switch (caseDef.opType)
357                 {
358                         case OPTYPE_SHUFFLE:
359                         case OPTYPE_SHUFFLE_XOR:
360                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
361                                 break;
362                         default:
363                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
364                                 break;
365                 }
366
367                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
368                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
369                         << "{\n"
370                         << "  uint result[];\n"
371                         << "};\n"
372                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
373                         << "{\n"
374                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
375                         << "};\n"
376                         << "layout(set = 0, binding = 2, std430) buffer Buffer3\n"
377                         << "{\n"
378                         << "  uint data2[];\n"
379                         << "};\n"
380                         << "\n"
381                         << "void main (void)\n"
382                         << "{\n"
383                         << "  uvec4 mask = subgroupBallot(true);\n"
384                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
385                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
386                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
387                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
388                         << "  {\n"
389                         << "    result[gl_VertexIndex] = (op == data1[id]) ? 1 : 0;\n"
390                         << "  }\n"
391                         << "  else\n"
392                         << "  {\n"
393                         << "    result[gl_VertexIndex] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
394                         << "  }\n"
395                         << "}\n";
396
397                 programCollection.glslSources.add("vert")
398                                 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
399         }
400         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
401         {
402                 programCollection.glslSources.add("vert")
403                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
404
405                 std::ostringstream src;
406
407                 src << "#version 450\n";
408
409                 switch (caseDef.opType)
410                 {
411                         case OPTYPE_SHUFFLE:
412                         case OPTYPE_SHUFFLE_XOR:
413                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
414                                 break;
415                         default:
416                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
417                                 break;
418                 }
419
420                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
421                         << "layout(points) in;\n"
422                         << "layout(points, max_vertices = 1) out;\n"
423                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
424                         << "{\n"
425                         << "  uint result[];\n"
426                         << "};\n"
427                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
428                         << "{\n"
429                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
430                         << "};\n"
431                         << "layout(set = 0, binding = 2, std430) buffer Buffer3\n"
432                         << "{\n"
433                         << "  uint data2[];\n"
434                         << "};\n"
435                         << "\n"
436                         << "void main (void)\n"
437                         << "{\n"
438                         << "  uvec4 mask = subgroupBallot(true);\n"
439                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
440                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
441                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
442                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
443                         << "  {\n"
444                         << "    result[gl_PrimitiveIDIn] = (op == data1[id]) ? 1 : 0;\n"
445                         << "  }\n"
446                         << "  else\n"
447                         << "  {\n"
448                         << "    result[gl_PrimitiveIDIn] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
449                         << "  }\n"
450                         << "}\n";
451
452                 programCollection.glslSources.add("geom")
453                                 << glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
454         }
455         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
456         {
457                 programCollection.glslSources.add("vert")
458                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
459
460                 programCollection.glslSources.add("tese")
461                                 << glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
462
463                 std::ostringstream src;
464
465                 src << "#version 450\n";
466
467                 switch (caseDef.opType)
468                 {
469                         case OPTYPE_SHUFFLE:
470                         case OPTYPE_SHUFFLE_XOR:
471                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
472                                 break;
473                         default:
474                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
475                                 break;
476                 }
477
478                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
479                         << "layout(vertices=1) out;\n"
480                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
481                         << "{\n"
482                         << "  uint result[];\n"
483                         << "};\n"
484                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
485                         << "{\n"
486                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
487                         << "};\n"
488                         << "layout(set = 0, binding = 2, std430) buffer Buffer3\n"
489                         << "{\n"
490                         << "  uint data2[];\n"
491                         << "};\n"
492                         << "\n"
493                         << "void main (void)\n"
494                         << "{\n"
495                         << "  uvec4 mask = subgroupBallot(true);\n"
496                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
497                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
498                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
499                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
500                         << "  {\n"
501                         << "    result[gl_PrimitiveID] = (op == data1[id]) ? 1 : 0;\n"
502                         << "  }\n"
503                         << "  else\n"
504                         << "  {\n"
505                         << "    result[gl_PrimitiveID] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
506                         << "  }\n"
507                         << "}\n";
508
509                 programCollection.glslSources.add("tesc")
510                                 << glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
511         }
512         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
513         {
514                 programCollection.glslSources.add("vert")
515                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
516
517                 programCollection.glslSources.add("tesc")
518                                 << glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
519
520                 std::ostringstream src;
521
522                 src << "#version 450\n";
523
524                 switch (caseDef.opType)
525                 {
526                         case OPTYPE_SHUFFLE:
527                         case OPTYPE_SHUFFLE_XOR:
528                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
529                                 break;
530                         default:
531                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
532                                 break;
533                 }
534
535                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
536                         << "layout(isolines) in;\n"
537                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
538                         << "{\n"
539                         << "  uint result[];\n"
540                         << "};\n"
541                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
542                         << "{\n"
543                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
544                         << "};\n"
545                         << "layout(set = 0, binding = 2, std430) buffer Buffer3\n"
546                         << "{\n"
547                         << "  uint data2[];\n"
548                         << "};\n"
549                         << "\n"
550                         << "void main (void)\n"
551                         << "{\n"
552                         << "  uvec4 mask = subgroupBallot(true);\n"
553                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
554                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
555                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
556                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
557                         << "  {\n"
558                         << "    result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = (op == data1[id]) ? 1 : 0;\n"
559                         << "  }\n"
560                         << "  else\n"
561                         << "  {\n"
562                         << "    result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
563                         << "  }\n"
564                         << "}\n";
565
566                 programCollection.glslSources.add("tese")
567                                 << glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
568         }
569         else
570         {
571                 DE_FATAL("Unsupported shader stage");
572         }
573 }
574
575 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
576 {
577         if (!subgroups::isSubgroupSupported(context))
578                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
579
580         if (!subgroups::areSubgroupOperationsSupportedForStage(
581                                 context, caseDef.shaderStage))
582         {
583                 if (subgroups::areSubgroupOperationsRequiredForStage(
584                                         caseDef.shaderStage))
585                 {
586                         return tcu::TestStatus::fail(
587                                            "Shader stage " +
588                                            subgroups::getShaderStageName(caseDef.shaderStage) +
589                                            " is required to support subgroup operations!");
590                 }
591                 else
592                 {
593                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
594                 }
595         }
596
597         switch (caseDef.opType)
598         {
599                 case OPTYPE_SHUFFLE:
600                 case OPTYPE_SHUFFLE_XOR:
601                         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_SHUFFLE_BIT))
602                         {
603                                 TCU_THROW(NotSupportedError, "Device does not support subgroup shuffle operations");
604                         }
605                         break;
606                 default:
607                         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_SHUFFLE_RELATIVE_BIT))
608                         {
609                                 TCU_THROW(NotSupportedError, "Device does not support subgroup shuffle relative operations");
610                         }
611                         break;
612         }
613
614         if (subgroups::isDoubleFormat(caseDef.format) &&
615                         !subgroups::isDoubleSupportedForDevice(context))
616         {
617                 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
618         }
619
620         //Tests which don't use the SSBO
621         if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
622         {
623                 subgroups::SSBOData inputData[2];
624                 inputData[0].format = caseDef.format;
625                 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
626                 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
627
628                 inputData[1].format = VK_FORMAT_R32_UINT;
629                 inputData[1].numElements = inputData[0].numElements;
630                 inputData[1].initializeType = subgroups::SSBOData::InitializeNonZero;
631
632                 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT,  inputData, 2, checkVertexPipelineStages);
633         }
634
635         if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
636                         (VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
637         {
638                 if (!subgroups::isVertexSSBOSupportedForDevice(context))
639                 {
640                         TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
641                 }
642         }
643
644         subgroups::SSBOData inputData[2];
645         inputData[0].format = caseDef.format;
646         inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
647         inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
648
649         inputData[1].format = VK_FORMAT_R32_UINT;
650         inputData[1].numElements = inputData[0].numElements;
651         inputData[1].initializeType = subgroups::SSBOData::InitializeNonZero;
652
653         if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
654         {
655                 return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT,
656                                                                                    inputData, 2, checkFragment);
657         }
658         else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
659         {
660                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT,
661                                                                                   inputData, 2, checkCompute);
662         }
663         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
664         {
665                 return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT,
666                                                                                  inputData, 2, checkVertexPipelineStages);
667         }
668         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
669         {
670                 return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT,
671                                                                                    inputData, 2, checkVertexPipelineStages);
672         }
673         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
674         {
675                 return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT,
676                                 inputData, 2, checkVertexPipelineStages);
677         }
678         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
679         {
680                 return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT,
681                                 inputData, 2, checkVertexPipelineStages);
682         }
683         else
684         {
685                 TCU_THROW(InternalError, "Unhandled shader stage");
686         }
687 }
688 }
689
690 namespace vkt
691 {
692 namespace subgroups
693 {
694 tcu::TestCaseGroup* createSubgroupsShuffleTests(tcu::TestContext& testCtx)
695 {
696         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
697                         testCtx, "shuffle", "Subgroup shuffle category tests"));
698
699         const VkShaderStageFlags stages[] =
700         {
701                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
702                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
703                 VK_SHADER_STAGE_GEOMETRY_BIT,
704                 VK_SHADER_STAGE_VERTEX_BIT,
705                 VK_SHADER_STAGE_FRAGMENT_BIT,
706                 VK_SHADER_STAGE_COMPUTE_BIT
707         };
708
709         const VkFormat formats[] =
710         {
711                 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
712                 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
713                 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
714                 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
715                 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
716                 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
717                 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
718                 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
719                 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
720         };
721
722         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
723         {
724                 const VkShaderStageFlags stage = stages[stageIndex];
725
726                 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
727                 {
728                         const VkFormat format = formats[formatIndex];
729
730                         for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
731                         {
732                                 CaseDefinition caseDef = {opTypeIndex, stage, format, false};
733
734                                 std::ostringstream name;
735
736                                 std::string op = getOpTypeName(opTypeIndex);
737
738                                 name << de::toLower(op)
739                                          << "_" << subgroups::getFormatNameForGLSL(format)
740                                          << "_" << getShaderStageName(stage);
741
742                                 addFunctionCaseWithPrograms(group.get(), name.str(),
743                                                                                         "", initPrograms, test, caseDef);
744
745                                 if (VK_SHADER_STAGE_VERTEX_BIT == stage )
746                                 {
747                                         caseDef.noSSBO = true;
748                                         addFunctionCaseWithPrograms(group.get(), name.str()+"_framebuffer", "",
749                                                                                                 initFrameBufferPrograms, test, caseDef);
750                                 }
751                         }
752                 }
753         }
754
755         return group.release();
756 }
757
758 } // subgroups
759 } // vkt