Merge remote-tracking branch 'goog/upstream-vulkan-cts-next' into vulkan-cts-1.1...
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / subgroups / vktSubgroupsShuffleTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsShuffleTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
//! Subgroup shuffle operations exercised by these tests; OPTYPE_LAST is the number of operations.
enum OpType
{
	OPTYPE_SHUFFLE		= 0,	// subgroupShuffle
	OPTYPE_SHUFFLE_XOR	= 1,	// subgroupShuffleXor
	OPTYPE_SHUFFLE_UP	= 2,	// subgroupShuffleUp
	OPTYPE_SHUFFLE_DOWN	= 3,	// subgroupShuffleDown
	OPTYPE_LAST
};
46
47 static bool checkVertexPipelineStages(std::vector<const void*> datas,
48                                                                           deUint32 width, deUint32)
49 {
50         const deUint32* data =
51                 reinterpret_cast<const deUint32*>(datas[0]);
52         for (deUint32 x = 0; x < width; ++x)
53         {
54                 deUint32 val = data[x];
55
56                 if (0x1 != val)
57                 {
58                         return false;
59                 }
60         }
61
62         return true;
63 }
64
65 static bool checkFragment(std::vector<const void*> datas,
66                                                   deUint32 width, deUint32 height, deUint32)
67 {
68         const deUint32* data =
69                 reinterpret_cast<const deUint32*>(datas[0]);
70         for (deUint32 x = 0; x < width; ++x)
71         {
72                 for (deUint32 y = 0; y < height; ++y)
73                 {
74                         deUint32 val = data[x * height + y];
75
76                         if (0x1 != val)
77                         {
78                                 return false;
79                         }
80                 }
81         }
82
83         return true;
84 }
85
86 static bool checkCompute(std::vector<const void*> datas,
87                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
88                                                  deUint32)
89 {
90         const deUint32* data =
91                 reinterpret_cast<const deUint32*>(datas[0]);
92
93         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
94         {
95                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
96                 {
97                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
98                         {
99                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
100                                 {
101                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
102                                         {
103                                                 for (deUint32 lZ = 0; lZ < localSize[2];
104                                                                 ++lZ)
105                                                 {
106                                                         const deUint32 globalInvocationX =
107                                                                 nX * localSize[0] + lX;
108                                                         const deUint32 globalInvocationY =
109                                                                 nY * localSize[1] + lY;
110                                                         const deUint32 globalInvocationZ =
111                                                                 nZ * localSize[2] + lZ;
112
113                                                         const deUint32 globalSizeX =
114                                                                 numWorkgroups[0] * localSize[0];
115                                                         const deUint32 globalSizeY =
116                                                                 numWorkgroups[1] * localSize[1];
117
118                                                         const deUint32 offset =
119                                                                 globalSizeX *
120                                                                 ((globalSizeY *
121                                                                   globalInvocationZ) +
122                                                                  globalInvocationY) +
123                                                                 globalInvocationX;
124
125                                                         if (0x1 != data[offset])
126                                                         {
127                                                                 return false;
128                                                         }
129                                                 }
130                                         }
131                                 }
132                         }
133                 }
134         }
135
136         return true;
137 }
138
139 std::string getOpTypeName(int opType)
140 {
141         switch (opType)
142         {
143                 default:
144                         DE_FATAL("Unsupported op type");
145                 case OPTYPE_SHUFFLE:
146                         return "subgroupShuffle";
147                 case OPTYPE_SHUFFLE_XOR:
148                         return "subgroupShuffleXor";
149                 case OPTYPE_SHUFFLE_UP:
150                         return "subgroupShuffleUp";
151                 case OPTYPE_SHUFFLE_DOWN:
152                         return "subgroupShuffleDown";
153         }
154 }
155
156 struct CaseDefinition
157 {
158         int                                     opType;
159         VkShaderStageFlags      shaderStage;
160         VkFormat                        format;
161         bool                            noSSBO;
162 };
163
164 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
165 {
166         std::string idTable[OPTYPE_LAST];
167         idTable[OPTYPE_SHUFFLE]                 = "data2[gl_SubgroupInvocationID]";
168         idTable[OPTYPE_SHUFFLE_XOR]             = "gl_SubgroupInvocationID ^ data2[gl_SubgroupInvocationID]";
169         idTable[OPTYPE_SHUFFLE_UP]              = "gl_SubgroupInvocationID - data2[gl_SubgroupInvocationID]";
170         idTable[OPTYPE_SHUFFLE_DOWN]    = "gl_SubgroupInvocationID + data2[gl_SubgroupInvocationID]";
171
172         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
173         {
174                 std::ostringstream      src;
175                 std::ostringstream      fragmentSrc;
176
177                 src << "#version 450\n"
178                         << "layout(location = 0) in highp vec4 in_position;\n"
179                         << "layout(location = 0) out float result;\n";
180
181                 switch (caseDef.opType)
182                 {
183                         case OPTYPE_SHUFFLE:
184                         case OPTYPE_SHUFFLE_XOR:
185                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
186                                 break;
187                         default:
188                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
189                                 break;
190                 }
191
192                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
193                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
194                         << "{\n"
195                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" << subgroups::maxSupportedSubgroupSize() << "];\n"
196                         << "};\n"
197                         << "layout(set = 0, binding = 1) uniform Buffer2\n"
198                         << "{\n"
199                         << "  uint data2[" << subgroups::maxSupportedSubgroupSize() << "];\n"
200                         << "};\n"
201                         << "\n"
202                         << "void main (void)\n"
203                         << "{\n"
204                         << "  uvec4 mask = subgroupBallot(true);\n"
205                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
206                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
207                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
208                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
209                         << "  {\n"
210                         << "    result = (op == data1[id]) ? 1.0f : 0.0f;\n"
211                         << "  }\n"
212                         << "  else\n"
213                         << "  {\n"
214                         << "    result = 1.0f; // Invocation we read from was inactive, so we can't verify results!\n"
215                         << "  }\n"
216                         << "  gl_Position = in_position;\n"
217                         << "  gl_PointSize = 1.0f;\n"
218                         << "}\n";
219
220                 programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
221
222                 fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
223                         << "layout(location = 0) in float result;\n"
224                         << "layout(location = 0) out uint out_color;\n"
225                         << "void main()\n"
226                         <<"{\n"
227                         << "    out_color = uint(result);\n"
228                         << "}\n";
229                 programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
230         }
231         else
232         {
233                 DE_FATAL("Unsupported shader stage");
234         }
235 }
236
237 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
238 {
239         std::string idTable[OPTYPE_LAST];
240         idTable[OPTYPE_SHUFFLE] = "data2[gl_SubgroupInvocationID]";
241         idTable[OPTYPE_SHUFFLE_XOR] = "gl_SubgroupInvocationID ^ data2[gl_SubgroupInvocationID]";
242         idTable[OPTYPE_SHUFFLE_UP] = "gl_SubgroupInvocationID - data2[gl_SubgroupInvocationID]";
243         idTable[OPTYPE_SHUFFLE_DOWN] = "gl_SubgroupInvocationID + data2[gl_SubgroupInvocationID]";
244
245         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
246         {
247                 std::ostringstream src;
248
249                 src << "#version 450\n";
250
251                 switch (caseDef.opType)
252                 {
253                         case OPTYPE_SHUFFLE:
254                         case OPTYPE_SHUFFLE_XOR:
255                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
256                                 break;
257                         default:
258                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
259                                 break;
260                 }
261
262                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
263                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
264                         "local_size_z_id = 2) in;\n"
265                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
266                         << "{\n"
267                         << "  uint result[];\n"
268                         << "};\n"
269                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
270                         << "{\n"
271                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
272                         << "};\n"
273                         << "layout(set = 0, binding = 2, std430) buffer Buffer3\n"
274                         << "{\n"
275                         << "  uint data2[];\n"
276                         << "};\n"
277                         << "\n"
278                         << "void main (void)\n"
279                         << "{\n"
280                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
281                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
282                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
283                         "gl_GlobalInvocationID.x;\n"
284                         << "  uvec4 mask = subgroupBallot(true);\n"
285                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
286                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
287                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
288                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
289                         << "  {\n"
290                         << "    result[offset] = (op == data1[id]) ? 1 : 0;\n"
291                         << "  }\n"
292                         << "  else\n"
293                         << "  {\n"
294                         << "    result[offset] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
295                         << "  }\n"
296                         << "}\n";
297
298                 programCollection.glslSources.add("comp")
299                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
300         }
301         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
302         {
303                 programCollection.glslSources.add("vert")
304                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
305
306                 std::ostringstream frag;
307
308                 frag << "#version 450\n";
309
310                 switch (caseDef.opType)
311                 {
312                         case OPTYPE_SHUFFLE:
313                         case OPTYPE_SHUFFLE_XOR:
314                                 frag << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
315                                 break;
316                         default:
317                                 frag << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
318                                 break;
319                 }
320
321                 frag << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
322                          << "layout(location = 0) out uint result;\n"
323                          << "layout(set = 0, binding = 0, std430) readonly buffer Buffer1\n"
324                          << "{\n"
325                          << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
326                          << "};\n"
327                          << "layout(set = 0, binding = 1, std430) readonly buffer Buffer2\n"
328                          << "{\n"
329                          << "  uint data2[];\n"
330                          << "};\n"
331                          << "void main (void)\n"
332                          << "{\n"
333                          << "  uvec4 mask = subgroupBallot(true);\n"
334                          << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
335                          << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
336                          << "  uint id = " << idTable[caseDef.opType] << ";\n"
337                          << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
338                          << "  {\n"
339                          << "    result = (op == data1[id]) ? 1 : 0;\n"
340                          << "  }\n"
341                          << "  else\n"
342                          << "  {\n"
343                          << "    result = 1; // Invocation we read from was inactive, so we can't verify results!\n"
344                          << "  }\n"
345                          << "}\n";
346
347                 programCollection.glslSources.add("frag")
348                                 << glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
349         }
350         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
351         {
352                 std::ostringstream src;
353
354                 src << "#version 450\n";
355
356                 switch (caseDef.opType)
357                 {
358                         case OPTYPE_SHUFFLE:
359                         case OPTYPE_SHUFFLE_XOR:
360                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
361                                 break;
362                         default:
363                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
364                                 break;
365                 }
366
367                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
368                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
369                         << "{\n"
370                         << "  uint result[];\n"
371                         << "};\n"
372                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
373                         << "{\n"
374                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
375                         << "};\n"
376                         << "layout(set = 0, binding = 2, std430) buffer Buffer3\n"
377                         << "{\n"
378                         << "  uint data2[];\n"
379                         << "};\n"
380                         << "\n"
381                         << "void main (void)\n"
382                         << "{\n"
383                         << "  uvec4 mask = subgroupBallot(true);\n"
384                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
385                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
386                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
387                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
388                         << "  {\n"
389                         << "    result[gl_VertexIndex] = (op == data1[id]) ? 1 : 0;\n"
390                         << "  }\n"
391                         << "  else\n"
392                         << "  {\n"
393                         << "    result[gl_VertexIndex] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
394                         << "  }\n"
395                         << "  gl_PointSize = 1.0f;\n"
396                         << "}\n";
397
398                 programCollection.glslSources.add("vert")
399                                 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
400         }
401         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
402         {
403                 programCollection.glslSources.add("vert")
404                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
405
406                 std::ostringstream src;
407
408                 src << "#version 450\n";
409
410                 switch (caseDef.opType)
411                 {
412                         case OPTYPE_SHUFFLE:
413                         case OPTYPE_SHUFFLE_XOR:
414                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
415                                 break;
416                         default:
417                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
418                                 break;
419                 }
420
421                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
422                         << "layout(points) in;\n"
423                         << "layout(points, max_vertices = 1) out;\n"
424                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
425                         << "{\n"
426                         << "  uint result[];\n"
427                         << "};\n"
428                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
429                         << "{\n"
430                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
431                         << "};\n"
432                         << "layout(set = 0, binding = 2, std430) buffer Buffer3\n"
433                         << "{\n"
434                         << "  uint data2[];\n"
435                         << "};\n"
436                         << "\n"
437                         << "void main (void)\n"
438                         << "{\n"
439                         << "  uvec4 mask = subgroupBallot(true);\n"
440                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
441                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
442                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
443                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
444                         << "  {\n"
445                         << "    result[gl_PrimitiveIDIn] = (op == data1[id]) ? 1 : 0;\n"
446                         << "  }\n"
447                         << "  else\n"
448                         << "  {\n"
449                         << "    result[gl_PrimitiveIDIn] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
450                         << "  }\n"
451                         << "}\n";
452
453                 programCollection.glslSources.add("geom")
454                                 << glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
455         }
456         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
457         {
458                 programCollection.glslSources.add("vert")
459                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
460
461                 programCollection.glslSources.add("tese")
462                                 << glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
463
464                 std::ostringstream src;
465
466                 src << "#version 450\n";
467
468                 switch (caseDef.opType)
469                 {
470                         case OPTYPE_SHUFFLE:
471                         case OPTYPE_SHUFFLE_XOR:
472                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
473                                 break;
474                         default:
475                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
476                                 break;
477                 }
478
479                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
480                         << "layout(vertices=1) out;\n"
481                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
482                         << "{\n"
483                         << "  uint result[];\n"
484                         << "};\n"
485                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
486                         << "{\n"
487                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
488                         << "};\n"
489                         << "layout(set = 0, binding = 2, std430) buffer Buffer3\n"
490                         << "{\n"
491                         << "  uint data2[];\n"
492                         << "};\n"
493                         << "\n"
494                         << "void main (void)\n"
495                         << "{\n"
496                         << "  uvec4 mask = subgroupBallot(true);\n"
497                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
498                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
499                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
500                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
501                         << "  {\n"
502                         << "    result[gl_PrimitiveID] = (op == data1[id]) ? 1 : 0;\n"
503                         << "  }\n"
504                         << "  else\n"
505                         << "  {\n"
506                         << "    result[gl_PrimitiveID] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
507                         << "  }\n"
508                         << "}\n";
509
510                 programCollection.glslSources.add("tesc")
511                                 << glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
512         }
513         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
514         {
515                 programCollection.glslSources.add("vert")
516                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
517
518                 programCollection.glslSources.add("tesc")
519                                 << glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
520
521                 std::ostringstream src;
522
523                 src << "#version 450\n";
524
525                 switch (caseDef.opType)
526                 {
527                         case OPTYPE_SHUFFLE:
528                         case OPTYPE_SHUFFLE_XOR:
529                                 src << "#extension GL_KHR_shader_subgroup_shuffle: enable\n";
530                                 break;
531                         default:
532                                 src << "#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
533                                 break;
534                 }
535
536                 src     << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
537                         << "layout(isolines) in;\n"
538                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
539                         << "{\n"
540                         << "  uint result[];\n"
541                         << "};\n"
542                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
543                         << "{\n"
544                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
545                         << "};\n"
546                         << "layout(set = 0, binding = 2, std430) buffer Buffer3\n"
547                         << "{\n"
548                         << "  uint data2[];\n"
549                         << "};\n"
550                         << "\n"
551                         << "void main (void)\n"
552                         << "{\n"
553                         << "  uvec4 mask = subgroupBallot(true);\n"
554                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
555                         << getOpTypeName(caseDef.opType) << "(data1[gl_SubgroupInvocationID], data2[gl_SubgroupInvocationID]);\n"
556                         << "  uint id = " << idTable[caseDef.opType] << ";\n"
557                         << "  if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
558                         << "  {\n"
559                         << "    result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = (op == data1[id]) ? 1 : 0;\n"
560                         << "  }\n"
561                         << "  else\n"
562                         << "  {\n"
563                         << "    result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
564                         << "  }\n"
565                         << "}\n";
566
567                 programCollection.glslSources.add("tese")
568                                 << glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
569         }
570         else
571         {
572                 DE_FATAL("Unsupported shader stage");
573         }
574 }
575
576 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
577 {
578         if (!subgroups::isSubgroupSupported(context))
579                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
580
581         if (!subgroups::areSubgroupOperationsSupportedForStage(
582                                 context, caseDef.shaderStage))
583         {
584                 if (subgroups::areSubgroupOperationsRequiredForStage(
585                                         caseDef.shaderStage))
586                 {
587                         return tcu::TestStatus::fail(
588                                            "Shader stage " +
589                                            subgroups::getShaderStageName(caseDef.shaderStage) +
590                                            " is required to support subgroup operations!");
591                 }
592                 else
593                 {
594                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
595                 }
596         }
597
598         switch (caseDef.opType)
599         {
600                 case OPTYPE_SHUFFLE:
601                 case OPTYPE_SHUFFLE_XOR:
602                         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_SHUFFLE_BIT))
603                         {
604                                 TCU_THROW(NotSupportedError, "Device does not support subgroup shuffle operations");
605                         }
606                         break;
607                 default:
608                         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_SHUFFLE_RELATIVE_BIT))
609                         {
610                                 TCU_THROW(NotSupportedError, "Device does not support subgroup shuffle relative operations");
611                         }
612                         break;
613         }
614
615         if (subgroups::isDoubleFormat(caseDef.format) &&
616                         !subgroups::isDoubleSupportedForDevice(context))
617         {
618                 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
619         }
620
621         //Tests which don't use the SSBO
622         if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
623         {
624                 subgroups::SSBOData inputData[2];
625                 inputData[0].format = caseDef.format;
626                 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
627                 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
628
629                 inputData[1].format = VK_FORMAT_R32_UINT;
630                 inputData[1].numElements = inputData[0].numElements;
631                 inputData[1].initializeType = subgroups::SSBOData::InitializeNonZero;
632
633                 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT,  inputData, 2, checkVertexPipelineStages);
634         }
635
636         if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
637                         (VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
638         {
639                 if (!subgroups::isVertexSSBOSupportedForDevice(context))
640                 {
641                         TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
642                 }
643         }
644
645         subgroups::SSBOData inputData[2];
646         inputData[0].format = caseDef.format;
647         inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
648         inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
649
650         inputData[1].format = VK_FORMAT_R32_UINT;
651         inputData[1].numElements = inputData[0].numElements;
652         inputData[1].initializeType = subgroups::SSBOData::InitializeNonZero;
653
654         if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
655         {
656                 return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT,
657                                                                                    inputData, 2, checkFragment);
658         }
659         else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
660         {
661                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT,
662                                                                                   inputData, 2, checkCompute);
663         }
664         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
665         {
666                 return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT,
667                                                                                  inputData, 2, checkVertexPipelineStages);
668         }
669         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
670         {
671                 return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT,
672                                                                                    inputData, 2, checkVertexPipelineStages);
673         }
674         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
675         {
676                 return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT,
677                                 inputData, 2, checkVertexPipelineStages);
678         }
679         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
680         {
681                 return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT,
682                                 inputData, 2, checkVertexPipelineStages);
683         }
684         else
685         {
686                 TCU_THROW(InternalError, "Unhandled shader stage");
687         }
688 }
689 }
690
691 namespace vkt
692 {
693 namespace subgroups
694 {
695 tcu::TestCaseGroup* createSubgroupsShuffleTests(tcu::TestContext& testCtx)
696 {
697         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
698                         testCtx, "shuffle", "Subgroup shuffle category tests"));
699
700         const VkShaderStageFlags stages[] =
701         {
702                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
703                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
704                 VK_SHADER_STAGE_GEOMETRY_BIT,
705                 VK_SHADER_STAGE_VERTEX_BIT,
706                 VK_SHADER_STAGE_FRAGMENT_BIT,
707                 VK_SHADER_STAGE_COMPUTE_BIT
708         };
709
710         const VkFormat formats[] =
711         {
712                 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
713                 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
714                 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
715                 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
716                 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
717                 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
718                 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
719                 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
720                 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
721         };
722
723         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
724         {
725                 const VkShaderStageFlags stage = stages[stageIndex];
726
727                 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
728                 {
729                         const VkFormat format = formats[formatIndex];
730
731                         for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
732                         {
733                                 CaseDefinition caseDef = {opTypeIndex, stage, format, false};
734
735                                 std::ostringstream name;
736
737                                 std::string op = getOpTypeName(opTypeIndex);
738
739                                 name << de::toLower(op)
740                                          << "_" << subgroups::getFormatNameForGLSL(format)
741                                          << "_" << getShaderStageName(stage);
742
743                                 addFunctionCaseWithPrograms(group.get(), name.str(),
744                                                                                         "", initPrograms, test, caseDef);
745
746                                 if (VK_SHADER_STAGE_VERTEX_BIT == stage )
747                                 {
748                                         caseDef.noSSBO = true;
749                                         addFunctionCaseWithPrograms(group.get(), name.str()+"_framebuffer", "",
750                                                                                                 initFrameBufferPrograms, test, caseDef);
751                                 }
752                         }
753                 }
754         }
755
756         return group.release();
757 }
758
759 } // subgroups
760 } // vkt