f201dd24a18ada6f97c2037d8fba0195b4dfe113
[platform/upstream/VK-GL-CTS.git] / external / openglcts / modules / common / subgroups / glcSubgroupsVoteTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsVoteTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
//! Subgroup vote operation exercised by a test case (see getOpTypeName()).
enum OpType
{
	OPTYPE_ALL,			//!< subgroupAll
	OPTYPE_ANY,			//!< subgroupAny
	OPTYPE_ALLEQUAL,	//!< subgroupAllEqual
	OPTYPE_LAST			//!< One past the last operation; has no GLSL name.
};
45
46 static bool checkVertexPipelineStages(std::vector<const void*> datas,
47                                                                           deUint32 width, deUint32)
48 {
49         return vkt::subgroups::check(datas, width, 0x1F);
50 }
51
52 static bool checkFragmentPipelineStages(std::vector<const void*> datas,
53                                                                           deUint32 width, deUint32 height, deUint32)
54 {
55         const deUint32* data =
56                 reinterpret_cast<const deUint32*>(datas[0]);
57         for (deUint32 x = 0u; x < width; ++x)
58         {
59                 for (deUint32 y = 0u; y < height; ++y)
60                 {
61                         const deUint32 ndx = (x * height + y);
62                         deUint32 val = data[ndx] & 0x1F;
63
64                         if (data[ndx] & 0x40) //Helper fragment shader invocation was executed
65                         {
66                                 if(val != 0x1F)
67                                         return false;
68                         }
69                         else //Helper fragment shader invocation was not executed yet
70                         {
71                                 if (val != 0x1E)
72                                         return false;
73                         }
74                 }
75         }
76         return true;
77 }
78
79 static bool checkCompute(std::vector<const void*> datas,
80                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
81                                                  deUint32)
82 {
83         return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 0x1F);
84 }
85
86 std::string getOpTypeName(int opType)
87 {
88         switch (opType)
89         {
90                 default:
91                         DE_FATAL("Unsupported op type");
92                         return "";
93                 case OPTYPE_ALL:
94                         return "subgroupAll";
95                 case OPTYPE_ANY:
96                         return "subgroupAny";
97                 case OPTYPE_ALLEQUAL:
98                         return "subgroupAllEqual";
99         }
100 }
101
102 struct CaseDefinition
103 {
104         int                                     opType;
105         VkShaderStageFlags      shaderStage;
106         VkFormat                        format;
107 };
108
109 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
110 {
111         const vk::ShaderBuildOptions buildOptions       (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
112         const bool formatIsBoolean =
113                 VK_FORMAT_R8_USCALED == caseDef.format || VK_FORMAT_R8G8_USCALED == caseDef.format || VK_FORMAT_R8G8B8_USCALED == caseDef.format || VK_FORMAT_R8G8B8A8_USCALED == caseDef.format;
114
115         if (VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage)
116                 subgroups::setFragmentShaderFrameBuffer(programCollection);
117
118         if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
119         {
120                 const string vertex     = "#version 450\n"
121                         "void main (void)\n"
122                         "{\n"
123                         "  vec2 uv = vec2(float(gl_VertexIndex & 1), float((gl_VertexIndex >> 1) & 1));\n"
124                         "  gl_Position = vec4(uv * 4.0f -2.0f, 0.0f, 1.0f);\n"
125                         "  gl_PointSize = 1.0f;\n"
126                         "}\n";
127                 programCollection.glslSources.add("vert") << glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
128         }
129         else if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
130                 subgroups::setVertexShaderFrameBuffer(programCollection);
131
132         const string source =
133                 (OPTYPE_ALL == caseDef.opType) ?
134                         "  result = " + getOpTypeName(caseDef.opType) +
135                         "(true) ? 0x1 : 0;\n"
136                         "  result |= " + getOpTypeName(caseDef.opType) +
137                         "(false) ? 0 : 0x1A;\n"
138                         "  result |= 0x4;\n"
139                 : (OPTYPE_ANY == caseDef.opType) ?
140                                 "  result = " + getOpTypeName(caseDef.opType) +
141                                 "(true) ? 0x1 : 0;\n"
142                                 "  result |= " + getOpTypeName(caseDef.opType) +
143                                 "(false) ? 0 : 0x1A;\n"
144                                 "  result |= 0x4;\n"
145                 : (OPTYPE_ALLEQUAL == caseDef.opType) ?
146                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n" +
147                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueNoEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + (formatIsBoolean ? "(subgroupElect())\n;" : "(12.0 * float(data[gl_SubgroupInvocationID]) + gl_SubgroupInvocationID);\n") +
148                                 "  result = " + getOpTypeName(caseDef.opType) + "("
149                                 + subgroups::getFormatNameForGLSL(caseDef.format) + "(1)) ? 0x1 : 0;\n"
150                                 "  result |= " + getOpTypeName(caseDef.opType) +
151                                 "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
152                                 "  result |= " + getOpTypeName(caseDef.opType) +
153                                 "(data[0]) ? 0x4 : 0;\n"
154                                 "  result |= " + getOpTypeName(caseDef.opType) +
155                                 "(valueEqual) ? 0x8 : 0x0;\n"
156                                 "  result |= " + getOpTypeName(caseDef.opType) +
157                                 "(valueNoEqual) ? 0x0 : 0x10;\n"
158                                 "  if (subgroupElect()) result |= 0x2 | 0x10;\n"
159                 : "";
160
161         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
162         {
163                 std::ostringstream vertexSrc;
164                 vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
165                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
166                         << "layout(location = 0) out vec4 out_color;\n"
167                         << "layout(location = 0) in highp vec4 in_position;\n"
168                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
169                         << "{\n"
170                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
171                         << "};\n"
172                         << "\n"
173                         << "void main (void)\n"
174                         << "{\n"
175                         << "  uint result;\n"
176                         << source
177                         << "  out_color.r = float(result);\n"
178                         << "  gl_Position = in_position;\n"
179                         << "  gl_PointSize = 1.0f;\n"
180                         << "}\n";
181
182                 programCollection.glslSources.add("vert") << glu::VertexSource(vertexSrc.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
183         }
184         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
185         {
186                 std::ostringstream geometry;
187
188                 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
189                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
190                         << "layout(points) in;\n"
191                         << "layout(points, max_vertices = 1) out;\n"
192                         << "layout(location = 0) out float out_color;\n"
193                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
194                         << "{\n"
195                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
196                         << "};\n"
197                         << "\n"
198                         << "void main (void)\n"
199                         << "{\n"
200                         << "  uint result;\n"
201                         << source
202                         << "  out_color = float(result);\n"
203                         << "  gl_Position = gl_in[0].gl_Position;\n"
204                         << "  EmitVertex();\n"
205                         << "  EndPrimitive();\n"
206                         << "}\n";
207
208                 programCollection.glslSources.add("geometry")
209                         << glu::GeometrySource(geometry.str()) << buildOptions;
210         }
211         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
212         {
213                 std::ostringstream controlSource;
214                 controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
215                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
216                         << "layout(vertices = 2) out;\n"
217                         << "layout(location = 0) out float out_color[];\n"
218                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
219                         << "{\n"
220                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
221                         << "};\n"
222                         << "\n"
223                         << "void main (void)\n"
224                         << "{\n"
225                         << "  uint result;\n"
226                         << "  if (gl_InvocationID == 0)\n"
227                         <<"  {\n"
228                         << "    gl_TessLevelOuter[0] = 1.0f;\n"
229                         << "    gl_TessLevelOuter[1] = 1.0f;\n"
230                         << "  }\n"
231                         << source
232                         << "  out_color[gl_InvocationID] = float(result);"
233                         << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
234                         << "}\n";
235
236                 programCollection.glslSources.add("tesc")
237                         << glu::TessellationControlSource(controlSource.str()) << buildOptions;
238                 subgroups::setTesEvalShaderFrameBuffer(programCollection);
239         }
240         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
241         {
242                 std::ostringstream evaluationSource;
243                 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
244                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
245                         << "#extension GL_EXT_tessellation_shader : require\n"
246                         << "layout(isolines, equal_spacing, ccw ) in;\n"
247                         << "layout(location = 0) out float out_color;\n"
248                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
249                         << "{\n"
250                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
251                         << "};\n"
252                         << "\n"
253                         << "void main (void)\n"
254                         << "{\n"
255                         << "  uint result;\n"
256                         << "  highp uint offset = gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5);\n"
257                         << source
258                         << "  out_color = float(result);\n"
259                         << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
260                         << "}\n";
261
262                 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
263                 programCollection.glslSources.add("tese")
264                                 << glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
265         }
266         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
267         {
268                 const string sourceFragment =
269                 (OPTYPE_ALL == caseDef.opType) ?
270                         "  result |= " + getOpTypeName(caseDef.opType) +
271                         "(!gl_HelperInvocation) ? 0x0 : 0x1;\n"
272                         "  result |= " + getOpTypeName(caseDef.opType) +
273                         "(false) ? 0 : 0x1A;\n"
274                         "  result |= 0x4;\n"
275                 : (OPTYPE_ANY == caseDef.opType) ?
276                                 "  result |= " + getOpTypeName(caseDef.opType) +
277                                 "(gl_HelperInvocation) ? 0x1 : 0x0;\n"
278                                 "  result |= " + getOpTypeName(caseDef.opType) +
279                                 "(false) ? 0 : 0x1A;\n"
280                                 "  result |= 0x4;\n"
281                 : (OPTYPE_ALLEQUAL == caseDef.opType) ?
282                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n" +
283                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueNoEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + (formatIsBoolean ? "(subgroupElect());\n" : "(12.0 * float(data[gl_SubgroupInvocationID]) + int(gl_FragCoord.x*gl_SubgroupInvocationID));\n") +
284                                 "  result |= " + getOpTypeName(caseDef.opType) + "("
285                                 + subgroups::getFormatNameForGLSL(caseDef.format) + "(1)) ? 0x10 : 0;\n"
286                                 "  result |= " + getOpTypeName(caseDef.opType) +
287                                 "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
288                                 "  result |= " + getOpTypeName(caseDef.opType) +
289                                 "(data[0]) ? 0x4 : 0;\n"
290                                 "  result |= " + getOpTypeName(caseDef.opType) +
291                                 "(valueEqual) ? 0x8 : 0x0;\n"
292                                 "  result |= " + getOpTypeName(caseDef.opType) +
293                                 "(gl_HelperInvocation) ? 0x0 : 0x1;\n"
294                                 "  if (subgroupElect()) result |= 0x2 | 0x10;\n"
295                 : "";
296
297                 std::ostringstream fragmentSource;
298                 fragmentSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
299                 << "#extension GL_KHR_shader_subgroup_vote: enable\n"
300                 << "layout(location = 0) out uint out_color;\n"
301                 << "layout(set = 0, binding = 0) uniform Buffer1\n"
302                 << "{\n"
303                 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
304                 << "};\n"
305                 << ""
306                 << "void main()\n"
307                 << "{\n"
308                 << "  uint result = 0u;\n"
309                 << "  if (dFdx(gl_SubgroupInvocationID * gl_FragCoord.x * gl_FragCoord.y) - dFdy(gl_SubgroupInvocationID * gl_FragCoord.x * gl_FragCoord.y) > 0.0f)\n"
310                 << "  {\n"
311                 << "    result |= 0x20;\n" // to be sure that compiler doesn't remove dFdx and dFdy executions
312                 << "  }\n"
313                 << "  bool helper = subgroupAny(gl_HelperInvocation);\n"
314                 << "  if (helper)\n"
315                 << "  {\n"
316                 << "    result |= 0x40;\n"
317                 << "  }\n"
318                 << sourceFragment
319                 << "  out_color = result;\n"
320                 << "}\n";
321
322                 programCollection.glslSources.add("fragment")
323                         << glu::FragmentSource(fragmentSource.str())<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
324         }
325         else
326         {
327                 DE_FATAL("Unsupported shader stage");
328         }
329 }
330
331 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
332 {
333         const bool formatIsBoolean =
334                 VK_FORMAT_R8_USCALED == caseDef.format || VK_FORMAT_R8G8_USCALED == caseDef.format || VK_FORMAT_R8G8B8_USCALED == caseDef.format || VK_FORMAT_R8G8B8A8_USCALED == caseDef.format;
335         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
336         {
337                 std::ostringstream src;
338
339                 src << "#version 450\n"
340                         << "#extension GL_KHR_shader_subgroup_vote: enable\n"
341                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
342                         "local_size_z_id = 2) in;\n"
343                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
344                         << "{\n"
345                         << "  uint result[];\n"
346                         << "};\n"
347                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
348                         << "{\n"
349                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
350                         << "};\n"
351                         << "\n"
352                         << "void main (void)\n"
353                         << "{\n"
354                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
355                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
356                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
357                         "gl_GlobalInvocationID.x;\n";
358                 if (OPTYPE_ALL == caseDef.opType)
359                 {
360                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
361                                 << "(true) ? 0x1 : 0;\n"
362                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
363                                 << "(false) ? 0 : 0x1A;\n"
364                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
365                                 << "(data[gl_SubgroupInvocationID] > 0) ? 0x4 : 0;\n";
366                 }
367                 else if (OPTYPE_ANY == caseDef.opType)
368                 {
369                         src << "  result[offset] = " << getOpTypeName(caseDef.opType)
370                                 << "(true) ? 0x1 : 0;\n"
371                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
372                                 << "(false) ? 0 : 0x1A;\n"
373                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
374                                 << "(data[gl_SubgroupInvocationID] == data[0]) ? 0x4 : 0;\n";
375                 }
376
377                 else if (OPTYPE_ALLEQUAL == caseDef.opType)
378                 {
379                         src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) <<" valueEqual = " << subgroups::getFormatNameForGLSL(caseDef.format) << "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n"
380                                 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) <<" valueNoEqual = " << subgroups::getFormatNameForGLSL(caseDef.format) << (formatIsBoolean ? "(subgroupElect());\n" : "(12.0 * float(data[gl_SubgroupInvocationID]) + offset);\n")
381                                 <<"  result[offset] = " << getOpTypeName(caseDef.opType) << "("
382                                 << subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0x0;\n"
383                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
384                                 << "(gl_SubgroupInvocationID) ? 0x0 : 0x2;\n"
385                                 << "  result[offset] |= " << getOpTypeName(caseDef.opType)
386                                 << "(data[0]) ? 0x4 : 0x0;\n"
387                                 << "  result[offset] |= "<< getOpTypeName(caseDef.opType)
388                                 << "(valueEqual) ? 0x8 : 0x0;\n"
389                                 << "  result[offset] |= "<< getOpTypeName(caseDef.opType)
390                                 << "(valueNoEqual) ? 0x0 : 0x10;\n"
391                                 << "  if (subgroupElect()) result[offset] |= 0x2 | 0x10;\n";
392                 }
393
394                 src << "}\n";
395
396                 programCollection.glslSources.add("comp")
397                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
398         }
399         else
400         {
401                 const string source =
402                 (OPTYPE_ALL == caseDef.opType) ?
403                         "  result[offset] = " + getOpTypeName(caseDef.opType) +
404                         "(true) ? 0x1 : 0;\n"
405                         "  result[offset] |= " + getOpTypeName(caseDef.opType) +
406                         "(false) ? 0 : 0x1A;\n"
407                         "  result[offset] |= 0x4;\n"
408                 : (OPTYPE_ANY == caseDef.opType) ?
409                                 "  result[offset] = " + getOpTypeName(caseDef.opType) +
410                                 "(true) ? 0x1 : 0;\n"
411                                 "  result[offset] |= " + getOpTypeName(caseDef.opType) +
412                                 "(false) ? 0 : 0x1A;\n"
413                                 "  result[offset] |= 0x4;\n"
414                 : (OPTYPE_ALLEQUAL == caseDef.opType) ?
415                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n" +
416                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueNoEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + (formatIsBoolean ? "(subgroupElect());\n" : "(12.0 * float(data[gl_SubgroupInvocationID]) + gl_SubgroupInvocationID);\n") +
417                                 "  result[offset] = " + getOpTypeName(caseDef.opType) + "("
418                                 + subgroups::getFormatNameForGLSL(caseDef.format) + "(1)) ? 0x1 : 0;\n"
419                                 "  result[offset] |= " + getOpTypeName(caseDef.opType) +
420                                 "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
421                                 "  result[offset] |= " + getOpTypeName(caseDef.opType) +
422                                 "(data[0]) ? 0x4 : 0;\n"
423                                 "  result[offset] |= " + getOpTypeName(caseDef.opType) +
424                                 "(valueEqual) ? 0x8 : 0x0;\n"
425                                 "  result[offset] |= " + getOpTypeName(caseDef.opType) +
426                                 "(valueNoEqual) ? 0x0 : 0x10;\n"
427                                 "  if (subgroupElect()) result[offset] |= 0x2 | 0x10;\n"
428                 : "";
429
430                 const string formatString = subgroups::getFormatNameForGLSL(caseDef.format);
431
432                 {
433                         const string vertex =
434                                 "#version 450\n"
435                                 "#extension GL_KHR_shader_subgroup_vote: enable\n"
436                                 "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
437                                 "{\n"
438                                 "  uint result[];\n"
439                                 "};\n"
440                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
441                                 "{\n"
442                                 "  " + formatString + " data[];\n"
443                                 "};\n"
444                                 "\n"
445                                 "void main (void)\n"
446                                 "{\n"
447                                 "  highp uint offset = gl_VertexIndex;\n"
448                                 + source +
449                                 "  float pixelSize = 2.0f/1024.0f;\n"
450                                 "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
451                                 "  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
452                                 "  gl_PointSize = 1.0f;\n"
453                                 "}\n";
454                         programCollection.glslSources.add("vert")
455                                 << glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
456                 }
457
458                 {
459                         const string tesc =
460                                 "#version 450\n"
461                                 "#extension GL_KHR_shader_subgroup_vote: enable\n"
462                                 "layout(vertices=1) out;\n"
463                                 "layout(set = 0, binding = 1, std430) buffer Buffer1\n"
464                                 "{\n"
465                                 "  uint result[];\n"
466                                 "};\n"
467                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
468                                 "{\n"
469                                 "  " + formatString + " data[];\n"
470                                 "};\n"
471                                 "\n"
472                                 "void main (void)\n"
473                                 "{\n"
474                                 "  highp uint offset = gl_PrimitiveID;\n"
475                                 + source +
476                                 "  if (gl_InvocationID == 0)\n"
477                                 "  {\n"
478                                 "    gl_TessLevelOuter[0] = 1.0f;\n"
479                                 "    gl_TessLevelOuter[1] = 1.0f;\n"
480                                 "  }\n"
481                                 "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
482                                 "}\n";
483
484                         programCollection.glslSources.add("tesc")
485                                         << glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
486                 }
487
488                 {
489                         const string tese =
490                                 "#version 450\n"
491                                 "#extension GL_KHR_shader_subgroup_vote: enable\n"
492                                 "layout(isolines) in;\n"
493                                 "layout(set = 0, binding = 2, std430) buffer Buffer1\n"
494                                 "{\n"
495                                 "  uint result[];\n"
496                                 "};\n"
497                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
498                                 "{\n"
499                                 "  " + formatString + " data[];\n"
500                                 "};\n"
501                                 "\n"
502                                 "void main (void)\n"
503                                 "{\n"
504                                 "  highp uint offset = gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5);\n"
505                                 + source +
506                                 "  float pixelSize = 2.0f/1024.0f;\n"
507                                 "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
508                                 "}\n";
509
510                         programCollection.glslSources.add("tese")
511                                         << glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
512                 }
513
514                 {
515                         const string geometry =
516                                 "#version 450\n"
517                                 "#extension GL_KHR_shader_subgroup_vote: enable\n"
518                                 "layout(${TOPOLOGY}) in;\n"
519                                 "layout(points, max_vertices = 1) out;\n"
520                                 "layout(set = 0, binding = 3, std430) buffer Buffer1\n"
521                                 "{\n"
522                                 "  uint result[];\n"
523                                 "};\n"
524                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
525                                 "{\n"
526                                 "  " + formatString + " data[];\n"
527                                 "};\n"
528                                 "\n"
529                                 "void main (void)\n"
530                                 "{\n"
531                                 "  highp uint offset = gl_PrimitiveIDIn;\n"
532                                 + source +
533                                 "  gl_Position = gl_in[0].gl_Position;\n"
534                                 "  EmitVertex();\n"
535                                 "  EndPrimitive();\n"
536                                 "}\n";
537
538                         subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
539                                                                                                           programCollection.glslSources);
540                 }
541
542                 {
543                         const string sourceFragment =
544                         (OPTYPE_ALL == caseDef.opType) ?
545                                 "  result = " + getOpTypeName(caseDef.opType) +
546                                 "(true) ? 0x1 : 0;\n"
547                                 "  result |= " + getOpTypeName(caseDef.opType) +
548                                 "(false) ? 0 : 0x1A;\n"
549                                 "  result |= 0x4;\n"
550                         : (OPTYPE_ANY == caseDef.opType) ?
551                                         "  result = " + getOpTypeName(caseDef.opType) +
552                                         "(true) ? 0x1 : 0;\n"
553                                         "  result |= " + getOpTypeName(caseDef.opType) +
554                                         "(false) ? 0 : 0x1A;\n"
555                                         "  result |= 0x4;\n"
556                         : (OPTYPE_ALLEQUAL == caseDef.opType) ?
557                                         "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n" +
558                                         "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueNoEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + (formatIsBoolean ? "(subgroupElect());\n" : "(12.0 * float(data[gl_SubgroupInvocationID]) + int(gl_FragCoord.x*gl_SubgroupInvocationID));\n") +
559                                         "  result = " + getOpTypeName(caseDef.opType) + "("
560                                         + subgroups::getFormatNameForGLSL(caseDef.format) + "(1)) ? 0x1 : 0;\n"
561                                         "  result |= " + getOpTypeName(caseDef.opType) +
562                                         "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
563                                         "  result |= " + getOpTypeName(caseDef.opType) +
564                                         "(data[0]) ? 0x4 : 0;\n"
565                                         "  result |= " + getOpTypeName(caseDef.opType) +
566                                         "(valueEqual) ? 0x8 : 0x0;\n"
567                                         "  result |= " + getOpTypeName(caseDef.opType) +
568                                         "(valueNoEqual) ? 0x0 : 0x10;\n"
569                                         "  if (subgroupElect()) result |= 0x2 | 0x10;\n"
570                         : "";
571                         const string fragment =
572                                 "#version 450\n"
573                                 "#extension GL_KHR_shader_subgroup_vote: enable\n"
574                                 "layout(location = 0) out uint result;\n"
575                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
576                                 "{\n"
577                                 "  " + formatString + " data[];\n"
578                                 "};\n"
579                                 "void main (void)\n"
580                                 "{\n"
581                                 + sourceFragment +
582                                 "}\n";
583
584                         programCollection.glslSources.add("fragment")
585                                 << glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
586                 }
587
588                 subgroups::addNoSubgroupShader(programCollection);
589         }
590 }
591
592 void supportedCheck (Context& context, CaseDefinition caseDef)
593 {
594         if (!subgroups::isSubgroupSupported(context))
595                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
596
597         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_VOTE_BIT))
598         {
599                 TCU_THROW(NotSupportedError, "Device does not support subgroup vote operations");
600         }
601
602         if (subgroups::isDoubleFormat(caseDef.format) &&
603                         !subgroups::isDoubleSupportedForDevice(context))
604         {
605                 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
606         }
607 }
608
609 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
610 {
611         if (!subgroups::areSubgroupOperationsSupportedForStage(
612                                 context, caseDef.shaderStage))
613         {
614                 if (subgroups::areSubgroupOperationsRequiredForStage(
615                                         caseDef.shaderStage))
616                 {
617                         return tcu::TestStatus::fail(
618                                            "Shader stage " +
619                                            subgroups::getShaderStageName(caseDef.shaderStage) +
620                                            " is required to support subgroup operations!");
621                 }
622                 else
623                 {
624                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
625                 }
626         }
627
628         subgroups::SSBOData inputData;
629         inputData.format = caseDef.format;
630         inputData.numElements = subgroups::maxSupportedSubgroupSize();
631         inputData.initializeType = OPTYPE_ALLEQUAL == caseDef.opType ? subgroups::SSBOData::InitializeZero : subgroups::SSBOData::InitializeNonZero;
632
633         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
634                 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
635         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
636                 return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
637         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
638                 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
639         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
640                 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
641         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
642                 return subgroups::makeFragmentFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkFragmentPipelineStages);
643         else
644                 TCU_THROW(InternalError, "Unhandled shader stage");
645 }
646
647
648 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
649 {
650         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
651         {
652                 if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
653                 {
654                         return tcu::TestStatus::fail(
655                                            "Shader stage " +
656                                            subgroups::getShaderStageName(caseDef.shaderStage) +
657                                            " is required to support subgroup operations!");
658                 }
659
660                 subgroups::SSBOData inputData;
661                 inputData.format = caseDef.format;
662                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
663                 inputData.initializeType = OPTYPE_ALLEQUAL == caseDef.opType ? subgroups::SSBOData::InitializeZero : subgroups::SSBOData::InitializeNonZero;
664
665                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData,
666                                                                                   1, checkCompute);
667         }
668         else
669         {
670                 VkPhysicalDeviceSubgroupProperties subgroupProperties;
671                 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
672                 subgroupProperties.pNext = DE_NULL;
673
674                 VkPhysicalDeviceProperties2 properties;
675                 properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
676                 properties.pNext = &subgroupProperties;
677
678                 context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
679
680                 VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage  & subgroupProperties.supportedStages);
681
682                 if (VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
683                 {
684                         if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
685                                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
686                         else
687                                 stages = VK_SHADER_STAGE_FRAGMENT_BIT;
688                 }
689
690                 if ((VkShaderStageFlagBits)0u == stages)
691                         TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
692
693                 subgroups::SSBOData inputData;
694                 inputData.format                        = caseDef.format;
695                 inputData.numElements           = subgroups::maxSupportedSubgroupSize();
696                 inputData.initializeType        = OPTYPE_ALLEQUAL == caseDef.opType ? subgroups::SSBOData::InitializeZero : subgroups::SSBOData::InitializeNonZero;
697                 inputData.binding                       = 4u;
698                 inputData.stages                        = stages;
699
700                 return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
701         }
702 }
703 }
704
705 namespace vkt
706 {
707 namespace subgroups
708 {
709 tcu::TestCaseGroup* createSubgroupsVoteTests(tcu::TestContext& testCtx)
710 {
711         de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
712                 testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
713         de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
714                 testCtx, "compute", "Subgroup arithmetic category tests: compute"));
715         de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
716                 testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
717
718         de::MovePtr<tcu::TestCaseGroup> fragHelperGroup(new tcu::TestCaseGroup(
719                 testCtx, "frag_helper", "Subgroup arithmetic category tests: fragment helper invocation"));
720
721         const VkShaderStageFlags stages[] =
722         {
723                 VK_SHADER_STAGE_VERTEX_BIT,
724                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
725                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
726                 VK_SHADER_STAGE_GEOMETRY_BIT,
727         };
728
729         const VkFormat formats[] =
730         {
731                 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
732                 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
733                 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
734                 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
735                 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
736                 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
737                 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
738                 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
739                 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
740         };
741
742         for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
743         {
744                 const VkFormat format = formats[formatIndex];
745
746                 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
747                 {
748                         // Skip the typed tests for all but subgroupAllEqual()
749                         if ((VK_FORMAT_R32_UINT != format) && (OPTYPE_ALLEQUAL != opTypeIndex))
750                         {
751                                 continue;
752                         }
753
754                         const std::string op = de::toLower(getOpTypeName(opTypeIndex));
755
756                         {
757                                 const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format};
758                                 addFunctionCaseWithPrograms(computeGroup.get(),
759                                                                                         op + "_" + subgroups::getFormatNameForGLSL(format),
760                                                                                         "", supportedCheck, initPrograms, test, caseDef);
761                         }
762
763                         {
764                                 const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_ALL_GRAPHICS, format};
765                                 addFunctionCaseWithPrograms(graphicGroup.get(),
766                                                                                         op + "_" + subgroups::getFormatNameForGLSL(format),
767                                                                                         "", supportedCheck, initPrograms, test, caseDef);
768                         }
769
770                         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
771                         {
772                                 const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
773                                 addFunctionCaseWithPrograms(framebufferGroup.get(),
774                                                         op + "_" +
775                                                         subgroups::getFormatNameForGLSL(format)
776                                                         + "_" + getShaderStageName(caseDef.shaderStage), "",
777                                                         supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
778                         }
779
780                         const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_FRAGMENT_BIT, format};
781                         addFunctionCaseWithPrograms(fragHelperGroup.get(),
782                                                 op + "_" +
783                                                 subgroups::getFormatNameForGLSL(format)
784                                                 + "_" + getShaderStageName(caseDef.shaderStage), "",
785                                                 supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
786                 }
787         }
788
789         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
790                 testCtx, "vote", "Subgroup vote category tests"));
791
792         group->addChild(graphicGroup.release());
793         group->addChild(computeGroup.release());
794         group->addChild(framebufferGroup.release());
795         group->addChild(fragHelperGroup.release());
796
797         return group.release();
798 }
799
800 } // subgroups
801 } // vkt