porting changes for OpenGL Subgroup tests
[platform/upstream/VK-GL-CTS.git] / external / openglcts / modules / common / subgroups / glcSubgroupsQuadTests.cpp
1 /*------------------------------------------------------------------------
2  * OpenGL Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017-2019 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  * Copyright (c) 2019 NVIDIA Corporation.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25
#include "glcSubgroupsQuadTests.hpp"
#include "glcSubgroupsTestsUtils.hpp"

#include <sstream>
#include <string>
#include <vector>
31
32 using namespace tcu;
33 using namespace std;
34
35 namespace glc
36 {
37 namespace subgroups
38 {
39 namespace
40 {
//! Quad operation exercised by a test case. The values are used directly as indices
//! into per-operation tables, so they must stay contiguous and start at zero.
enum OpType
{
	OPTYPE_QUAD_BROADCAST		= 0,	//!< subgroupQuadBroadcast(value, lane)
	OPTYPE_QUAD_SWAP_HORIZONTAL	= 1,	//!< subgroupQuadSwapHorizontal(value)
	OPTYPE_QUAD_SWAP_VERTICAL	= 2,	//!< subgroupQuadSwapVertical(value)
	OPTYPE_QUAD_SWAP_DIAGONAL	= 3,	//!< subgroupQuadSwapDiagonal(value)
	OPTYPE_LAST					= 4		//!< Number of operations; not itself a valid operation.
};
49
50 static bool checkVertexPipelineStages(std::vector<const void*> datas,
51                                                                           deUint32 width, deUint32)
52 {
53         return glc::subgroups::check(datas, width, 1);
54 }
55
56 static bool checkComputeStage(std::vector<const void*> datas,
57                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
58                                                  deUint32)
59 {
60         return glc::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
61 }
62
63 std::string getOpTypeName(int opType)
64 {
65         switch (opType)
66         {
67                 default:
68                         DE_FATAL("Unsupported op type");
69                         return "";
70                 case OPTYPE_QUAD_BROADCAST:
71                         return "subgroupQuadBroadcast";
72                 case OPTYPE_QUAD_SWAP_HORIZONTAL:
73                         return "subgroupQuadSwapHorizontal";
74                 case OPTYPE_QUAD_SWAP_VERTICAL:
75                         return "subgroupQuadSwapVertical";
76                 case OPTYPE_QUAD_SWAP_DIAGONAL:
77                         return "subgroupQuadSwapDiagonal";
78         }
79 }
80
81 struct CaseDefinition
82 {
83         int                                     opType;
84         ShaderStageFlags        shaderStage;
85         Format                          format;
86         int                                     direction;
87 };
88
89 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
90 {
91         std::string                     swapTable[OPTYPE_LAST];
92
93         subgroups::setFragmentShaderFrameBuffer(programCollection);
94
95         if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
96                 subgroups::setVertexShaderFrameBuffer(programCollection);
97
98         swapTable[OPTYPE_QUAD_BROADCAST] = "";
99         swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = {1, 0, 3, 2};\n";
100         swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = "  const uint swapTable[4] = {2, 3, 0, 1};\n";
101         swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = "  const uint swapTable[4] = {3, 2, 1, 0};\n";
102
103         if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
104         {
105                 std::ostringstream      vertexSrc;
106                 vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
107                         << "#extension GL_KHR_shader_subgroup_quad: enable\n"
108                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
109                         << "layout(location = 0) in highp vec4 in_position;\n"
110                         << "layout(location = 0) out float result;\n"
111                         << "layout(binding = 0) uniform Buffer0\n"
112                         << "{\n"
113                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
114                         << "};\n"
115                         << "\n"
116                         << "void main (void)\n"
117                         << "{\n"
118                         << "  uvec4 mask = subgroupBallot(true);\n"
119                         << swapTable[caseDef.opType];
120
121                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
122                 {
123                         vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
124                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
125                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
126                 }
127                 else
128                 {
129                         vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
130                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
131                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
132                 }
133
134                 vertexSrc << "  if (subgroupBallotBitExtract(mask, otherID))\n"
135                         << "  {\n"
136                         << "    result = (op == data[otherID]) ? 1.0f : 0.0f;\n"
137                         << "  }\n"
138                         << "  else\n"
139                         << "  {\n"
140                         << "    result = 1.0f;\n" // Invocation we read from was inactive, so we can't verify results!
141                         << "  }\n"
142                         << "  gl_Position = in_position;\n"
143                         << "  gl_PointSize = 1.0f;\n"
144                         << "}\n";
145                 programCollection.add("vert") << glu::VertexSource(vertexSrc.str());
146         }
147         else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
148         {
149                 std::ostringstream geometry;
150
151                 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
152                         << "#extension GL_KHR_shader_subgroup_quad: enable\n"
153                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
154                         << "layout(points) in;\n"
155                         << "layout(points, max_vertices = 1) out;\n"
156                         << "layout(location = 0) out float out_color;\n"
157                         << "layout(binding = 0) uniform Buffer0\n"
158                         << "{\n"
159                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
160                         << "};\n"
161                         << "\n"
162                         << "void main (void)\n"
163                         << "{\n"
164                         << "  uvec4 mask = subgroupBallot(true);\n"
165                         << swapTable[caseDef.opType];
166
167                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
168                 {
169                         geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
170                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
171                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
172                 }
173                 else
174                 {
175                         geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
176                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
177                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
178                 }
179
180                 geometry << "  if (subgroupBallotBitExtract(mask, otherID))\n"
181                         << "  {\n"
182                         << "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
183                         << "  }\n"
184                         << "  else\n"
185                         << "  {\n"
186                         << "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
187                         << "  }\n"
188                         << "  gl_Position = gl_in[0].gl_Position;\n"
189                         << "  EmitVertex();\n"
190                         << "  EndPrimitive();\n"
191                         << "}\n";
192
193                 programCollection.add("geometry") << glu::GeometrySource(geometry.str());
194         }
195         else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
196         {
197                 std::ostringstream controlSource;
198
199                 controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
200                         << "#extension GL_KHR_shader_subgroup_quad: enable\n"
201                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
202                         << "layout(vertices = 2) out;\n"
203                         << "layout(location = 0) out float out_color[];\n"
204                         << "layout(binding = 0) uniform Buffer0\n"
205                         << "{\n"
206                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
207                         << "};\n"
208                         << "\n"
209                         << "void main (void)\n"
210                         << "{\n"
211                         << "  if (gl_InvocationID == 0)\n"
212                         <<"  {\n"
213                         << "    gl_TessLevelOuter[0] = 1.0f;\n"
214                         << "    gl_TessLevelOuter[1] = 1.0f;\n"
215                         << "  }\n"
216                         << "  uvec4 mask = subgroupBallot(true);\n"
217                         << swapTable[caseDef.opType];
218
219                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
220                 {
221                         controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
222                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
223                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
224                 }
225                 else
226                 {
227                         controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
228                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
229                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
230                 }
231
232                 controlSource << "  if (subgroupBallotBitExtract(mask, otherID))\n"
233                         << "  {\n"
234                         << "    out_color[gl_InvocationID] = (op == data[otherID]) ? 1.0 : 0.0;\n"
235                         << "  }\n"
236                         << "  else\n"
237                         << "  {\n"
238                         << "    out_color[gl_InvocationID] = 1.0; \n"// Invocation we read from was inactive, so we can't verify results!
239                         << "  }\n"
240                         << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
241                         << "}\n";
242
243                 programCollection.add("tesc") << glu::TessellationControlSource(controlSource.str());
244                 subgroups::setTesEvalShaderFrameBuffer(programCollection);
245         }
246         else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
247         {
248                 ostringstream evaluationSource;
249                 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
250                         << "#extension GL_KHR_shader_subgroup_quad: enable\n"
251                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
252                         << "layout(isolines, equal_spacing, ccw ) in;\n"
253                         << "layout(location = 0) out float out_color;\n"
254                         << "layout(binding = 0) uniform Buffer0\n"
255                         << "{\n"
256                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
257                         << "};\n"
258                         << "\n"
259                         << "void main (void)\n"
260                         << "{\n"
261                         << "  uvec4 mask = subgroupBallot(true);\n"
262                         << swapTable[caseDef.opType];
263
264                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
265                 {
266                         evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
267                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
268                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
269                 }
270                 else
271                 {
272                         evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
273                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
274                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
275                 }
276
277                 evaluationSource << "  if (subgroupBallotBitExtract(mask, otherID))\n"
278                         << "  {\n"
279                         << "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
280                         << "  }\n"
281                         << "  else\n"
282                         << "  {\n"
283                         << "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
284                         << "  }\n"
285                         << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
286                         << "}\n";
287
288                 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
289                 programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str());
290         }
291         else
292         {
293                 DE_FATAL("Unsupported shader stage");
294         }
295 }
296
297 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
298 {
299         std::string swapTable[OPTYPE_LAST];
300         swapTable[OPTYPE_QUAD_BROADCAST] = "";
301         swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = {1, 0, 3, 2};\n";
302         swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = "  const uint swapTable[4] = {2, 3, 0, 1};\n";
303         swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = "  const uint swapTable[4] = {3, 2, 1, 0};\n";
304
305         if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
306         {
307                 std::ostringstream src;
308
309                 src << "#version 450\n"
310                         << "#extension GL_KHR_shader_subgroup_quad: enable\n"
311                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
312                         << "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
313                         << "layout(binding = 0, std430) buffer Buffer0\n"
314                         << "{\n"
315                         << "  uint result[];\n"
316                         << "};\n"
317                         << "layout(binding = 1, std430) buffer Buffer1\n"
318                         << "{\n"
319                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
320                         << "};\n"
321                         << "\n"
322                         << "void main (void)\n"
323                         << "{\n"
324                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
325                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
326                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
327                         "gl_GlobalInvocationID.x;\n"
328                         << "  uvec4 mask = subgroupBallot(true);\n"
329                         << swapTable[caseDef.opType];
330
331
332                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
333                 {
334                         src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
335                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
336                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
337                 }
338                 else
339                 {
340                         src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
341                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
342                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
343                 }
344
345                 src << "  if (subgroupBallotBitExtract(mask, otherID))\n"
346                         << "  {\n"
347                         << "    result[offset] = (op == data[otherID]) ? 1 : 0;\n"
348                         << "  }\n"
349                         << "  else\n"
350                         << "  {\n"
351                         << "    result[offset] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
352                         << "  }\n"
353                         << "}\n";
354
355                 programCollection.add("comp") << glu::ComputeSource(src.str());
356         }
357         else
358         {
359                 std::ostringstream src;
360                 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
361                 {
362                         src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
363                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
364                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
365                 }
366                 else
367                 {
368                         src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
369                                 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
370                                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
371                 }
372                 const string sourceType = src.str();
373
374                 {
375                         const string vertex =
376                                 "#version 450\n"
377                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
378                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
379                                 "layout(binding = 0, std430) buffer Buffer0\n"
380                                 "{\n"
381                                 "  uint result[];\n"
382                                 "} b0;\n"
383                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
384                                 "{\n"
385                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
386                                 "};\n"
387                                 "\n"
388                                 "void main (void)\n"
389                                 "{\n"
390                                 "  uvec4 mask = subgroupBallot(true);\n"
391                                 + swapTable[caseDef.opType]
392                                 + sourceType +
393                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
394                                 "  {\n"
395                                 "    b0.result[gl_VertexID] = (op == data[otherID]) ? 1 : 0;\n"
396                                 "  }\n"
397                                 "  else\n"
398                                 "  {\n"
399                                 "    b0.result[gl_VertexID] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
400                                 "  }\n"
401                                 "  float pixelSize = 2.0f/1024.0f;\n"
402                                 "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
403                                 "  gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
404                                 "}\n";
405                         programCollection.add("vert") << glu::VertexSource(vertex);
406                 }
407
408                 {
409                         const string tesc =
410                                 "#version 450\n"
411                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
412                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
413                                 "layout(vertices=1) out;\n"
414                                 "layout(binding = 1, std430) buffer Buffer1\n"
415                                 "{\n"
416                                 "  uint result[];\n"
417                                 "} b1;\n"
418                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
419                                 "{\n"
420                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
421                                 "};\n"
422                                 "\n"
423                                 "void main (void)\n"
424                                 "{\n"
425                                 "  uvec4 mask = subgroupBallot(true);\n"
426                                 + swapTable[caseDef.opType]
427                                 + sourceType +
428                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
429                                 "  {\n"
430                                 "    b1.result[gl_PrimitiveID] = (op == data[otherID]) ? 1 : 0;\n"
431                                 "  }\n"
432                                 "  else\n"
433                                 "  {\n"
434                                 "    b1.result[gl_PrimitiveID] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
435                                 "  }\n"
436                                 "  if (gl_InvocationID == 0)\n"
437                                 "  {\n"
438                                 "    gl_TessLevelOuter[0] = 1.0f;\n"
439                                 "    gl_TessLevelOuter[1] = 1.0f;\n"
440                                 "  }\n"
441                                 "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
442                                 "}\n";
443                         programCollection.add("tesc") << glu::TessellationControlSource(tesc);
444                 }
445
446                 {
447                         const string tese =
448                                 "#version 450\n"
449                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
450                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
451                                 "layout(isolines) in;\n"
452                                 "layout(binding = 2, std430)  buffer Buffer2\n"
453                                 "{\n"
454                                 "  uint result[];\n"
455                                 "} b2;\n"
456                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
457                                 "{\n"
458                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
459                                 "};\n"
460                                 "\n"
461                                 "void main (void)\n"
462                                 "{\n"
463                                 "  uvec4 mask = subgroupBallot(true);\n"
464                                 + swapTable[caseDef.opType]
465                                 + sourceType +
466                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
467                                 "  {\n"
468                                 "    b2.result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = (op == data[otherID]) ? 1 : 0;\n"
469                                 "  }\n"
470                                 "  else\n"
471                                 "  {\n"
472                                 "    b2.result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
473                                 "  }\n"
474                                 "  float pixelSize = 2.0f/1024.0f;\n"
475                                 "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
476                                 "}\n";
477                         programCollection.add("tese") << glu::TessellationEvaluationSource(tese);
478                 }
479
480                 {
481                         const string geometry =
482                                 "#version 450\n"
483                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
484                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
485                                 "layout(${TOPOLOGY}) in;\n"
486                                 "layout(points, max_vertices = 1) out;\n"
487                                 "layout(binding = 3, std430) buffer Buffer3\n"
488                                 "{\n"
489                                 "  uint result[];\n"
490                                 "} b3;\n"
491                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
492                                 "{\n"
493                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
494                                 "};\n"
495                                 "\n"
496                                 "void main (void)\n"
497                                 "{\n"
498                                 "  uvec4 mask = subgroupBallot(true);\n"
499                                 + swapTable[caseDef.opType]
500                                 + sourceType +
501                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
502                                 "  {\n"
503                                 "    b3.result[gl_PrimitiveIDIn] = (op == data[otherID]) ? 1 : 0;\n"
504                                 "  }\n"
505                                 "  else\n"
506                                 "  {\n"
507                                 "    b3.result[gl_PrimitiveIDIn] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
508                                 "  }\n"
509                                 "  gl_Position = gl_in[0].gl_Position;\n"
510                                 "  EmitVertex();\n"
511                                 "  EndPrimitive();\n"
512                                 "}\n";
513                         subgroups::addGeometryShadersFromTemplate(geometry, programCollection);
514                 }
515
516                 {
517                         const string fragment =
518                                 "#version 450\n"
519                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
520                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
521                                 "layout(location = 0) out uint result;\n"
522                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
523                                 "{\n"
524                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
525                                 "};\n"
526                                 "void main (void)\n"
527                                 "{\n"
528                                 "  uvec4 mask = subgroupBallot(true);\n"
529                                 + swapTable[caseDef.opType]
530                                 + sourceType +
531                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
532                                 "  {\n"
533                                 "    result = (op == data[otherID]) ? 1 : 0;\n"
534                                 "  }\n"
535                                 "  else\n"
536                                 "  {\n"
537                                 "    result = 1; // Invocation we read from was inactive, so we can't verify results!\n"
538                                 "  }\n"
539                                 "}\n";
540                         programCollection.add("fragment") << glu::FragmentSource(fragment);
541                 }
542                 subgroups::addNoSubgroupShader(programCollection);
543         }
544 }
545
546 void supportedCheck (Context& context, CaseDefinition caseDef)
547 {
548         if (!subgroups::isSubgroupSupported(context))
549                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
550
551         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, SUBGROUP_FEATURE_QUAD_BIT))
552                 TCU_THROW(NotSupportedError, "Device does not support subgroup quad operations");
553
554
555         if (subgroups::isDoubleFormat(caseDef.format) &&
556                         !subgroups::isDoubleSupportedForDevice(context))
557         {
558                 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
559         }
560 }
561
562 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
563 {
564         if (!subgroups::areSubgroupOperationsSupportedForStage(
565                                 context, caseDef.shaderStage))
566         {
567                 if (subgroups::areSubgroupOperationsRequiredForStage(
568                                         caseDef.shaderStage))
569                 {
570                         return tcu::TestStatus::fail(
571                                            "Shader stage " +
572                                            subgroups::getShaderStageName(caseDef.shaderStage) +
573                                            " is required to support subgroup operations!");
574                 }
575                 else
576                 {
577                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
578                 }
579         }
580
581         subgroups::SSBOData inputData;
582         inputData.format = caseDef.format;
583         inputData.numElements = subgroups::maxSupportedSubgroupSize();
584         inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
585         inputData.binding = 0u;
586
587         if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
588                 return subgroups::makeVertexFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
589         else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
590                 return subgroups::makeGeometryFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
591         else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
592                 return subgroups::makeTessellationEvaluationFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_CONTROL_BIT);
593         else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
594                 return subgroups::makeTessellationEvaluationFrameBufferTest(context,  FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_EVALUATION_BIT);
595         else
596                 TCU_THROW(InternalError, "Unhandled shader stage");
597 }
598
599
600 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
601 {
602         if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
603         {
604                 if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
605                 {
606                         return tcu::TestStatus::fail(
607                                            "Shader stage " +
608                                            subgroups::getShaderStageName(caseDef.shaderStage) +
609                                            " is required to support subgroup operations!");
610                 }
611                 subgroups::SSBOData inputData;
612                 inputData.format = caseDef.format;
613                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
614                 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
615                 inputData.binding = 1u;
616
617                 return subgroups::makeComputeTest(context, FORMAT_R32_UINT, &inputData, 1, checkComputeStage);
618         }
619         else
620         {
621                 int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
622
623                 ShaderStageFlags stages = (ShaderStageFlags)(caseDef.shaderStage & supportedStages);
624
625                 if (SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
626                 {
627                         if ( (stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
628                                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
629                         else
630                                 stages = SHADER_STAGE_FRAGMENT_BIT;
631                 }
632
633                 if ((ShaderStageFlags)0u == stages)
634                         TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
635
636                 subgroups::SSBOData inputData;
637                 inputData.format                        = caseDef.format;
638                 inputData.numElements           = subgroups::maxSupportedSubgroupSize();
639                 inputData.initializeType        = subgroups::SSBOData::InitializeNonZero;
640                 inputData.binding                       = 4u;
641                 inputData.stages                        = stages;
642
643                 return subgroups::allStages(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
644         }
645 }
646 }
647
648 deqp::TestCaseGroup* createSubgroupsQuadTests(deqp::Context& testCtx)
649 {
650         de::MovePtr<deqp::TestCaseGroup> graphicGroup(new deqp::TestCaseGroup(
651                 testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
652         de::MovePtr<deqp::TestCaseGroup> computeGroup(new deqp::TestCaseGroup(
653                 testCtx, "compute", "Subgroup arithmetic category tests: compute"));
654         de::MovePtr<deqp::TestCaseGroup> framebufferGroup(new deqp::TestCaseGroup(
655                 testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
656
657         const Format formats[] =
658         {
659                 FORMAT_R32_SINT, FORMAT_R32G32_SINT, FORMAT_R32G32B32_SINT,
660                 FORMAT_R32G32B32A32_SINT, FORMAT_R32_UINT, FORMAT_R32G32_UINT,
661                 FORMAT_R32G32B32_UINT, FORMAT_R32G32B32A32_UINT,
662                 FORMAT_R32_SFLOAT, FORMAT_R32G32_SFLOAT,
663                 FORMAT_R32G32B32_SFLOAT, FORMAT_R32G32B32A32_SFLOAT,
664                 FORMAT_R64_SFLOAT, FORMAT_R64G64_SFLOAT,
665                 FORMAT_R64G64B64_SFLOAT, FORMAT_R64G64B64A64_SFLOAT,
666                 FORMAT_R32_BOOL, FORMAT_R32G32_BOOL,
667                 FORMAT_R32G32B32_BOOL, FORMAT_R32G32B32A32_BOOL,
668         };
669
670         const ShaderStageFlags stages[] =
671         {
672                 SHADER_STAGE_VERTEX_BIT,
673                 SHADER_STAGE_TESS_EVALUATION_BIT,
674                 SHADER_STAGE_TESS_CONTROL_BIT,
675                 SHADER_STAGE_GEOMETRY_BIT,
676         };
677
678         for (int direction = 0; direction < 4; ++direction)
679         {
680                 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
681                 {
682                         const Format format = formats[formatIndex];
683
684                         for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
685                         {
686                                 const std::string op = de::toLower(getOpTypeName(opTypeIndex));
687                                 std::ostringstream name;
688                                 name << de::toLower(op);
689
690                                 if (OPTYPE_QUAD_BROADCAST == opTypeIndex)
691                                 {
692                                         name << "_" << direction;
693                                 }
694                                 else
695                                 {
696                                         if (0 != direction)
697                                         {
698                                                 // We don't need direction for swap operations.
699                                                 continue;
700                                         }
701                                 }
702
703                                 name << "_" << subgroups::getFormatNameForGLSL(format);
704
705                                 {
706                                         const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_COMPUTE_BIT, format, direction};
707                                         SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(computeGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
708                                 }
709
710                                 {
711                                         const CaseDefinition caseDef =
712                                         {
713                                                 opTypeIndex,
714                                                 SHADER_STAGE_ALL_GRAPHICS,
715                                                 format,
716                                                 direction
717                                         };
718                                         SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(graphicGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
719                                 }
720                                 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
721                                 {
722                                         const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format, direction};
723                                         SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(framebufferGroup.get(), name.str()+"_"+ getShaderStageName(caseDef.shaderStage), "",
724                                                                                                 supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
725                                 }
726
727                         }
728                 }
729         }
730
731         de::MovePtr<deqp::TestCaseGroup> group(new deqp::TestCaseGroup(
732                 testCtx, "quad", "Subgroup quad category tests"));
733
734         group->addChild(graphicGroup.release());
735         group->addChild(computeGroup.release());
736         group->addChild(framebufferGroup.release());
737
738         return group.release();
739 }
740 } // subgroups
741 } // glc