1abe2564eab44a0af515dc973544d784dd9bcee6
[platform/upstream/VK-GL-CTS.git] / external / openglcts / modules / common / subgroups / glcSubgroupsQuadTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsQuadTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
// Quad subgroup operations exercised by this file. The enumerator values are
// used directly as indices into per-operation tables sized OPTYPE_LAST, so
// they must stay dense and start at zero.
enum OpType
{
	OPTYPE_QUAD_BROADCAST		= 0,	// subgroupQuadBroadcast
	OPTYPE_QUAD_SWAP_HORIZONTAL	= 1,	// subgroupQuadSwapHorizontal
	OPTYPE_QUAD_SWAP_VERTICAL	= 2,	// subgroupQuadSwapVertical
	OPTYPE_QUAD_SWAP_DIAGONAL	= 3,	// subgroupQuadSwapDiagonal
	OPTYPE_LAST					= 4		// number of operations; not an operation itself
};
46
47 static bool checkVertexPipelineStages(std::vector<const void*> datas,
48                                                                           deUint32 width, deUint32)
49 {
50         return vkt::subgroups::check(datas, width, 1);
51 }
52
53 static bool checkCompute(std::vector<const void*> datas,
54                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
55                                                  deUint32)
56 {
57         return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
58 }
59
60 std::string getOpTypeName(int opType)
61 {
62         switch (opType)
63         {
64                 default:
65                         DE_FATAL("Unsupported op type");
66                         return "";
67                 case OPTYPE_QUAD_BROADCAST:
68                         return "subgroupQuadBroadcast";
69                 case OPTYPE_QUAD_SWAP_HORIZONTAL:
70                         return "subgroupQuadSwapHorizontal";
71                 case OPTYPE_QUAD_SWAP_VERTICAL:
72                         return "subgroupQuadSwapVertical";
73                 case OPTYPE_QUAD_SWAP_DIAGONAL:
74                         return "subgroupQuadSwapDiagonal";
75         }
76 }
77
78 struct CaseDefinition
79 {
80         int                                     opType;
81         VkShaderStageFlags      shaderStage;
82         VkFormat                        format;
83         int                                     direction;
84 };
85
86 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
87 {
88         const vk::ShaderBuildOptions    buildOptions    (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
89         std::string                     swapTable[OPTYPE_LAST];
90
91         subgroups::setFragmentShaderFrameBuffer(programCollection);
92
93         if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
94                 subgroups::setVertexShaderFrameBuffer(programCollection);
95
96         swapTable[OPTYPE_QUAD_BROADCAST] = "";
97         swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = {1, 0, 3, 2};\n";
98         swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = "  const uint swapTable[4] = {2, 3, 0, 1};\n";
99         swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = "  const uint swapTable[4] = {3, 2, 1, 0};\n";
100
101         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
102         {
103                 std::ostringstream      vertexSrc;
104                 vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
105                         << "#extension GL_KHR_shader_subgroup_quad: enable\n"
106                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
107                         << "layout(location = 0) in highp vec4 in_position;\n"
108                         << "layout(location = 0) out float result;\n"
109                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
110                         << "{\n"
111                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
112                         << "};\n"
113                         << "\n"
114                         << "void main (void)\n"
115                         << "{\n"
116                         << "  uvec4 mask = subgroupBallot(true);\n"
117                         << swapTable[caseDef.opType];
118
119                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
120                 {
121                         vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
122                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
123                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
124                 }
125                 else
126                 {
127                         vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
128                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
129                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
130                 }
131
132                 vertexSrc << "  if (subgroupBallotBitExtract(mask, otherID))\n"
133                         << "  {\n"
134                         << "    result = (op == data[otherID]) ? 1.0f : 0.0f;\n"
135                         << "  }\n"
136                         << "  else\n"
137                         << "  {\n"
138                         << "    result = 1.0f;\n" // Invocation we read from was inactive, so we can't verify results!
139                         << "  }\n"
140                         << "  gl_Position = in_position;\n"
141                         << "  gl_PointSize = 1.0f;\n"
142                         << "}\n";
143                 programCollection.glslSources.add("vert")
144                         << glu::VertexSource(vertexSrc.str()) << buildOptions;
145         }
146         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
147         {
148                 std::ostringstream geometry;
149
150                 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
151                         << "#extension GL_KHR_shader_subgroup_quad: enable\n"
152                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
153                         << "layout(points) in;\n"
154                         << "layout(points, max_vertices = 1) out;\n"
155                         << "layout(location = 0) out float out_color;\n"
156
157                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
158                         << "{\n"
159                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
160                         << "};\n"
161                         << "\n"
162                         << "void main (void)\n"
163                         << "{\n"
164                         << "  uvec4 mask = subgroupBallot(true);\n"
165                         << swapTable[caseDef.opType];
166
167                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
168                 {
169                         geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
170                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
171                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
172                 }
173                 else
174                 {
175                         geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
176                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
177                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
178                 }
179
180                 geometry << "  if (subgroupBallotBitExtract(mask, otherID))\n"
181                         << "  {\n"
182                         << "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
183                         << "  }\n"
184                         << "  else\n"
185                         << "  {\n"
186                         << "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
187                         << "  }\n"
188                         << "  gl_Position = gl_in[0].gl_Position;\n"
189                         << "  EmitVertex();\n"
190                         << "  EndPrimitive();\n"
191                         << "}\n";
192
193                 programCollection.glslSources.add("geometry")
194                         << glu::GeometrySource(geometry.str()) << buildOptions;
195         }
196         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
197         {
198                 std::ostringstream controlSource;
199
200                 controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
201                         << "#extension GL_KHR_shader_subgroup_quad: enable\n"
202                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
203                         << "layout(vertices = 2) out;\n"
204                         << "layout(location = 0) out float out_color[];\n"
205                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
206                         << "{\n"
207                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
208                         << "};\n"
209                         << "\n"
210                         << "void main (void)\n"
211                         << "{\n"
212                         << "  if (gl_InvocationID == 0)\n"
213                         <<"  {\n"
214                         << "    gl_TessLevelOuter[0] = 1.0f;\n"
215                         << "    gl_TessLevelOuter[1] = 1.0f;\n"
216                         << "  }\n"
217                         << "  uvec4 mask = subgroupBallot(true);\n"
218                         << swapTable[caseDef.opType];
219
220                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
221                 {
222                         controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
223                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
224                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
225                 }
226                 else
227                 {
228                         controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
229                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
230                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
231                 }
232
233                 controlSource << "  if (subgroupBallotBitExtract(mask, otherID))\n"
234                         << "  {\n"
235                         << "    out_color[gl_InvocationID] = (op == data[otherID]) ? 1.0 : 0.0;\n"
236                         << "  }\n"
237                         << "  else\n"
238                         << "  {\n"
239                         << "    out_color[gl_InvocationID] = 1.0; \n"// Invocation we read from was inactive, so we can't verify results!
240                         << "  }\n"
241                         << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
242                         << "}\n";
243
244                 programCollection.glslSources.add("tesc")
245                         << glu::TessellationControlSource(controlSource.str()) << buildOptions;
246                 subgroups::setTesEvalShaderFrameBuffer(programCollection);
247         }
248         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
249         {
250                 ostringstream evaluationSource;
251                 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
252                         << "#extension GL_KHR_shader_subgroup_quad: enable\n"
253                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
254                         << "layout(isolines, equal_spacing, ccw ) in;\n"
255                         << "layout(location = 0) out float out_color;\n"
256                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
257                         << "{\n"
258                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
259                         << "};\n"
260                         << "\n"
261                         << "void main (void)\n"
262                         << "{\n"
263                         << "  uvec4 mask = subgroupBallot(true);\n"
264                         << swapTable[caseDef.opType];
265
266                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
267                 {
268                         evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
269                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
270                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
271                 }
272                 else
273                 {
274                         evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
275                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
276                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
277                 }
278
279                 evaluationSource << "  if (subgroupBallotBitExtract(mask, otherID))\n"
280                         << "  {\n"
281                         << "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
282                         << "  }\n"
283                         << "  else\n"
284                         << "  {\n"
285                         << "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
286                         << "  }\n"
287                         << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
288                         << "}\n";
289
290                 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
291                 programCollection.glslSources.add("tese")
292                                 << glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
293         }
294         else
295         {
296                 DE_FATAL("Unsupported shader stage");
297         }
298 }
299
300 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
301 {
302         std::string swapTable[OPTYPE_LAST];
303         swapTable[OPTYPE_QUAD_BROADCAST] = "";
304         swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = {1, 0, 3, 2};\n";
305         swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = "  const uint swapTable[4] = {2, 3, 0, 1};\n";
306         swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = "  const uint swapTable[4] = {3, 2, 1, 0};\n";
307
308         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
309         {
310                 std::ostringstream src;
311
312                 src << "#version 450\n"
313                         << "#extension GL_KHR_shader_subgroup_quad: enable\n"
314                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
315                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
316                         "local_size_z_id = 2) in;\n"
317                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
318                         << "{\n"
319                         << "  uint result[];\n"
320                         << "};\n"
321                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
322                         << "{\n"
323                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
324                         << "};\n"
325                         << "\n"
326                         << "void main (void)\n"
327                         << "{\n"
328                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
329                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
330                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
331                         "gl_GlobalInvocationID.x;\n"
332                         << "  uvec4 mask = subgroupBallot(true);\n"
333                         << swapTable[caseDef.opType];
334
335
336                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
337                 {
338                         src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
339                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
340                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
341                 }
342                 else
343                 {
344                         src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
345                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
346                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
347                 }
348
349                 src << "  if (subgroupBallotBitExtract(mask, otherID))\n"
350                         << "  {\n"
351                         << "    result[offset] = (op == data[otherID]) ? 1 : 0;\n"
352                         << "  }\n"
353                         << "  else\n"
354                         << "  {\n"
355                         << "    result[offset] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
356                         << "  }\n"
357                         << "}\n";
358
359                 programCollection.glslSources.add("comp")
360                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
361         }
362         else
363         {
364                 std::ostringstream src;
365                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
366                 {
367                         src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
368                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
369                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
370                 }
371                 else
372                 {
373                         src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
374                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
375                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
376                 }
377                 const string sourceType = src.str();
378
379                 {
380                         const string vertex =
381                                 "#version 450\n"
382                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
383                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
384                                 "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
385                                 "{\n"
386                                 "  uint result[];\n"
387                                 "};\n"
388                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
389                                 "{\n"
390                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
391                                 "};\n"
392                                 "\n"
393                                 "void main (void)\n"
394                                 "{\n"
395                                 "  uvec4 mask = subgroupBallot(true);\n"
396                                 + swapTable[caseDef.opType]
397                                 + sourceType +
398                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
399                                 "  {\n"
400                                 "    result[gl_VertexIndex] = (op == data[otherID]) ? 1 : 0;\n"
401                                 "  }\n"
402                                 "  else\n"
403                                 "  {\n"
404                                 "    result[gl_VertexIndex] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
405                                 "  }\n"
406                                 "  float pixelSize = 2.0f/1024.0f;\n"
407                                 "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
408                                 "  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
409                                 "}\n";
410                         programCollection.glslSources.add("vert")
411                                 << glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
412                 }
413
414                 {
415                         const string tesc =
416                                 "#version 450\n"
417                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
418                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
419                                 "layout(vertices=1) out;\n"
420                                 "layout(set = 0, binding = 1, std430) buffer Buffer1\n"
421                                 "{\n"
422                                 "  uint result[];\n"
423                                 "};\n"
424                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
425                                 "{\n"
426                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
427                                 "};\n"
428                                 "\n"
429                                 "void main (void)\n"
430                                 "{\n"
431                                 "  uvec4 mask = subgroupBallot(true);\n"
432                                 + swapTable[caseDef.opType]
433                                 + sourceType +
434                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
435                                 "  {\n"
436                                 "    result[gl_PrimitiveID] = (op == data[otherID]) ? 1 : 0;\n"
437                                 "  }\n"
438                                 "  else\n"
439                                 "  {\n"
440                                 "    result[gl_PrimitiveID] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
441                                 "  }\n"
442                                 "  if (gl_InvocationID == 0)\n"
443                                 "  {\n"
444                                 "    gl_TessLevelOuter[0] = 1.0f;\n"
445                                 "    gl_TessLevelOuter[1] = 1.0f;\n"
446                                 "  }\n"
447                                 "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
448                                 "}\n";
449                         programCollection.glslSources.add("tesc")
450                                         << glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
451                 }
452
453                 {
454                         const string tese =
455                                 "#version 450\n"
456                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
457                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
458                                 "layout(isolines) in;\n"
459                                 "layout(set = 0, binding = 2, std430)  buffer Buffer1\n"
460                                 "{\n"
461                                 "  uint result[];\n"
462                                 "};\n"
463                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
464                                 "{\n"
465                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
466                                 "};\n"
467                                 "\n"
468                                 "void main (void)\n"
469                                 "{\n"
470                                 "  uvec4 mask = subgroupBallot(true);\n"
471                                 + swapTable[caseDef.opType]
472                                 + sourceType +
473                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
474                                 "  {\n"
475                                 "    result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = (op == data[otherID]) ? 1 : 0;\n"
476                                 "  }\n"
477                                 "  else\n"
478                                 "  {\n"
479                                 "    result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
480                                 "  }\n"
481                                 "  float pixelSize = 2.0f/1024.0f;\n"
482                                 "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
483                                 "}\n";
484                         programCollection.glslSources.add("tese")
485                                         << glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
486                 }
487
488                 {
489                         const string geometry =
490                                 "#version 450\n"
491                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
492                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
493                                 "layout(${TOPOLOGY}) in;\n"
494                                 "layout(points, max_vertices = 1) out;\n"
495                                 "layout(set = 0, binding = 3, std430) buffer Buffer1\n"
496                                 "{\n"
497                                 "  uint result[];\n"
498                                 "};\n"
499                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
500                                 "{\n"
501                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
502                                 "};\n"
503                                 "\n"
504                                 "void main (void)\n"
505                                 "{\n"
506                                 "  uvec4 mask = subgroupBallot(true);\n"
507                                 + swapTable[caseDef.opType]
508                                 + sourceType +
509                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
510                                 "  {\n"
511                                 "    result[gl_PrimitiveIDIn] = (op == data[otherID]) ? 1 : 0;\n"
512                                 "  }\n"
513                                 "  else\n"
514                                 "  {\n"
515                                 "    result[gl_PrimitiveIDIn] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
516                                 "  }\n"
517                                 "  gl_Position = gl_in[0].gl_Position;\n"
518                                 "  EmitVertex();\n"
519                                 "  EndPrimitive();\n"
520                                 "}\n";
521                         subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
522                                                                                                           programCollection.glslSources);
523                 }
524
525                 {
526                         const string fragment =
527                                 "#version 450\n"
528                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
529                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
530                                 "layout(location = 0) out uint result;\n"
531                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
532                                 "{\n"
533                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
534                                 "};\n"
535                                 "void main (void)\n"
536                                 "{\n"
537                                 "  uvec4 mask = subgroupBallot(true);\n"
538                                 + swapTable[caseDef.opType]
539                                 + sourceType +
540                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
541                                 "  {\n"
542                                 "    result = (op == data[otherID]) ? 1 : 0;\n"
543                                 "  }\n"
544                                 "  else\n"
545                                 "  {\n"
546                                 "    result = 1; // Invocation we read from was inactive, so we can't verify results!\n"
547                                 "  }\n"
548                                 "}\n";
549                         programCollection.glslSources.add("fragment")
550                                 << glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
551                 }
552                 subgroups::addNoSubgroupShader(programCollection);
553         }
554 }
555
556 void supportedCheck (Context& context, CaseDefinition caseDef)
557 {
558         if (!subgroups::isSubgroupSupported(context))
559                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
560
561         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_QUAD_BIT))
562                 TCU_THROW(NotSupportedError, "Device does not support subgroup quad operations");
563
564
565         if (subgroups::isDoubleFormat(caseDef.format) &&
566                         !subgroups::isDoubleSupportedForDevice(context))
567         {
568                 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
569         }
570 }
571
572 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
573 {
574         if (!subgroups::areSubgroupOperationsSupportedForStage(
575                                 context, caseDef.shaderStage))
576         {
577                 if (subgroups::areSubgroupOperationsRequiredForStage(
578                                         caseDef.shaderStage))
579                 {
580                         return tcu::TestStatus::fail(
581                                            "Shader stage " +
582                                            subgroups::getShaderStageName(caseDef.shaderStage) +
583                                            " is required to support subgroup operations!");
584                 }
585                 else
586                 {
587                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
588                 }
589         }
590
591         subgroups::SSBOData inputData;
592         inputData.format = caseDef.format;
593         inputData.numElements = subgroups::maxSupportedSubgroupSize();
594         inputData.initializeType = subgroups::SSBOData::InitializeNonZero;;
595
596         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
597                 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
598         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
599                 return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
600         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
601                 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
602         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
603                 return subgroups::makeTessellationEvaluationFrameBufferTest(context,  VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
604         else
605                 TCU_THROW(InternalError, "Unhandled shader stage");
606 }
607
608
609 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
610 {
611         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
612         {
613                 if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
614                 {
615                         return tcu::TestStatus::fail(
616                                            "Shader stage " +
617                                            subgroups::getShaderStageName(caseDef.shaderStage) +
618                                            " is required to support subgroup operations!");
619                 }
620                 subgroups::SSBOData inputData;
621                 inputData.format = caseDef.format;
622                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
623                 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
624
625                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkCompute);
626         }
627         else
628         {
629                 VkPhysicalDeviceSubgroupProperties subgroupProperties;
630                 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
631                 subgroupProperties.pNext = DE_NULL;
632
633                 VkPhysicalDeviceProperties2 properties;
634                 properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
635                 properties.pNext = &subgroupProperties;
636
637                 context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
638
639                 VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage  & subgroupProperties.supportedStages);
640
641                 if (VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
642                 {
643                         if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
644                                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
645                         else
646                                 stages = VK_SHADER_STAGE_FRAGMENT_BIT;
647                 }
648
649                 if ((VkShaderStageFlagBits)0u == stages)
650                         TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
651
652                 subgroups::SSBOData inputData;
653                 inputData.format                        = caseDef.format;
654                 inputData.numElements           = subgroups::maxSupportedSubgroupSize();
655                 inputData.initializeType        = subgroups::SSBOData::InitializeNonZero;
656                 inputData.binding                       = 4u;
657                 inputData.stages                        = stages;
658
659                 return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
660         }
661 }
662 }
663
664 namespace vkt
665 {
666 namespace subgroups
667 {
668 tcu::TestCaseGroup* createSubgroupsQuadTests(tcu::TestContext& testCtx)
669 {
670         de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
671                 testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
672         de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
673                 testCtx, "compute", "Subgroup arithmetic category tests: compute"));
674         de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
675                 testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
676
677         const VkFormat formats[] =
678         {
679                 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
680                 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
681                 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
682                 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
683                 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
684                 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
685                 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
686                 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
687                 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
688         };
689
690         const VkShaderStageFlags stages[] =
691         {
692                 VK_SHADER_STAGE_VERTEX_BIT,
693                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
694                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
695                 VK_SHADER_STAGE_GEOMETRY_BIT,
696         };
697
698         for (int direction = 0; direction < 4; ++direction)
699         {
700                 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
701                 {
702                         const VkFormat format = formats[formatIndex];
703
704                         for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
705                         {
706                                 const std::string op = de::toLower(getOpTypeName(opTypeIndex));
707                                 std::ostringstream name;
708                                 name << de::toLower(op);
709
710                                 if (OPTYPE_QUAD_BROADCAST == opTypeIndex)
711                                 {
712                                         name << "_" << direction;
713                                 }
714                                 else
715                                 {
716                                         if (0 != direction)
717                                         {
718                                                 // We don't need direction for swap operations.
719                                                 continue;
720                                         }
721                                 }
722
723                                 name << "_" << subgroups::getFormatNameForGLSL(format);
724
725                                 {
726                                         const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format, direction};
727                                         addFunctionCaseWithPrograms(computeGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
728                                 }
729
730                                 {
731                                         const CaseDefinition caseDef =
732                                         {
733                                                 opTypeIndex,
734                                                 VK_SHADER_STAGE_ALL_GRAPHICS,
735                                                 format,
736                                                 direction
737                                         };
738                                         addFunctionCaseWithPrograms(graphicGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
739                                 }
740                                 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
741                                 {
742                                         const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format, direction};
743                                         addFunctionCaseWithPrograms(framebufferGroup.get(), name.str()+"_"+ getShaderStageName(caseDef.shaderStage), "",
744                                                                                                 supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
745                                 }
746
747                         }
748                 }
749         }
750
751         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
752                 testCtx, "quad", "Subgroup quad category tests"));
753
754         group->addChild(graphicGroup.release());
755         group->addChild(computeGroup.release());
756         group->addChild(framebufferGroup.release());
757
758         return group.release();
759 }
760 } // subgroups
761 } // vkt