// Porting changes for OpenGL Subgroup tests
// [platform/upstream/VK-GL-CTS.git] / external / openglcts / modules / common / subgroups / glcSubgroupsBallotOtherTests.cpp
1 /*------------------------------------------------------------------------
2  * OpenGL Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017-2019 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  * Copyright (c) 2019 NVIDIA Corporation.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25
26 #include "glcSubgroupsBallotOtherTests.hpp"
27 #include "glcSubgroupsTestsUtils.hpp"
28
29 #include <string>
30 #include <vector>
31
32 using namespace tcu;
33 using namespace std;
34
35 namespace glc
36 {
37 namespace subgroups
38 {
39 namespace
40 {
/*! Ballot helper built-ins exercised by this test group.
 *  Values are dense and start at zero: case registration iterates
 *  0 .. OPTYPE_LAST-1 and stores the index in CaseDefinition::opType. */
enum OpType
{
	OPTYPE_INVERSE_BALLOT				= 0,	//!< subgroupInverseBallot
	OPTYPE_BALLOT_BIT_EXTRACT			= 1,	//!< subgroupBallotBitExtract
	OPTYPE_BALLOT_BIT_COUNT				= 2,	//!< subgroupBallotBitCount
	OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT	= 3,	//!< subgroupBallotInclusiveBitCount
	OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT	= 4,	//!< subgroupBallotExclusiveBitCount
	OPTYPE_BALLOT_FIND_LSB				= 5,	//!< subgroupBallotFindLSB
	OPTYPE_BALLOT_FIND_MSB				= 6,	//!< subgroupBallotFindMSB
	OPTYPE_LAST								//!< Number of op types (loop sentinel)
};
52
53 static bool checkVertexPipelineStages(std::vector<const void*> datas,
54                                                                           deUint32 width, deUint32)
55 {
56         return glc::subgroups::check(datas, width, 0xf);
57 }
58
59 static bool checkComputeStage(std::vector<const void*> datas,
60                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
61                                                  deUint32)
62 {
63         return glc::subgroups::checkCompute(datas, numWorkgroups, localSize, 0xf);
64 }
65
66 std::string getOpTypeName(int opType)
67 {
68         switch (opType)
69         {
70                 default:
71                         DE_FATAL("Unsupported op type");
72                         return "";
73                 case OPTYPE_INVERSE_BALLOT:
74                         return "subgroupInverseBallot";
75                 case OPTYPE_BALLOT_BIT_EXTRACT:
76                         return "subgroupBallotBitExtract";
77                 case OPTYPE_BALLOT_BIT_COUNT:
78                         return "subgroupBallotBitCount";
79                 case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
80                         return "subgroupBallotInclusiveBitCount";
81                 case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
82                         return "subgroupBallotExclusiveBitCount";
83                 case OPTYPE_BALLOT_FIND_LSB:
84                         return "subgroupBallotFindLSB";
85                 case OPTYPE_BALLOT_FIND_MSB:
86                         return "subgroupBallotFindMSB";
87         }
88 }
89
90 struct CaseDefinition
91 {
92         int                                     opType;
93         ShaderStageFlags        shaderStage;
94 };
95
96 std::string getBodySource(CaseDefinition caseDef)
97 {
98         std::ostringstream bdy;
99
100         bdy << "  uvec4 allOnes = uvec4(0xFFFFFFFF);\n"
101                 << "  uvec4 allZeros = uvec4(0);\n"
102                 << "  uint tempResult = 0;\n"
103                 << "#define MAKE_HIGH_BALLOT_RESULT(i) uvec4("
104                 << "i >= 32 ? 0 : (0xFFFFFFFF << i), "
105                 << "i >= 64 ? 0 : (0xFFFFFFFF << ((i < 32) ? 0 : (i - 32))), "
106                 << "i >= 96 ? 0 : (0xFFFFFFFF << ((i < 64) ? 0 : (i - 64))), "
107                 << " 0xFFFFFFFF << ((i < 96) ? 0 : (i - 96)))\n"
108                 << "#define MAKE_SINGLE_BIT_BALLOT_RESULT(i) uvec4("
109                 << "i >= 32 ? 0 : 0x1 << i, "
110                 << "i < 32 || i >= 64 ? 0 : 0x1 << (i - 32), "
111                 << "i < 64 || i >= 96 ? 0 : 0x1 << (i - 64), "
112                 << "i < 96 ? 0 : 0x1 << (i - 96))\n";
113
114         switch (caseDef.opType)
115         {
116                 default:
117                         DE_FATAL("Unknown op type!");
118                         break;
119                 case OPTYPE_INVERSE_BALLOT:
120                         bdy << "  tempResult |= subgroupInverseBallot(allOnes) ? 0x1 : 0;\n"
121                                 << "  tempResult |= subgroupInverseBallot(allZeros) ? 0 : 0x2;\n"
122                                 << "  tempResult |= subgroupInverseBallot(subgroupBallot(true)) ? 0x4 : 0;\n"
123                                 << "  tempResult |= 0x8;\n";
124                         break;
125                 case OPTYPE_BALLOT_BIT_EXTRACT:
126                         bdy << "  tempResult |= subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID) ? 0x1 : 0;\n"
127                                 << "  tempResult |= subgroupBallotBitExtract(allZeros, gl_SubgroupInvocationID) ? 0 : 0x2;\n"
128                                 << "  tempResult |= subgroupBallotBitExtract(subgroupBallot(true), gl_SubgroupInvocationID) ? 0x4 : 0;\n"
129                                 << "  tempResult |= 0x8;\n"
130                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
131                                 << "  {\n"
132                                 << "    if (!subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID))\n"
133                                 << "    {\n"
134                                 << "      tempResult &= ~0x8;\n"
135                                 << "    }\n"
136                                 << "  }\n";
137                         break;
138                 case OPTYPE_BALLOT_BIT_COUNT:
139                         bdy << "  tempResult |= gl_SubgroupSize == subgroupBallotBitCount(allOnes) ? 0x1 : 0;\n"
140                                 << "  tempResult |= 0 == subgroupBallotBitCount(allZeros) ? 0x2 : 0;\n"
141                                 << "  tempResult |= 0 < subgroupBallotBitCount(subgroupBallot(true)) ? 0x4 : 0;\n"
142                                 << "  tempResult |= 0 == subgroupBallotBitCount(MAKE_HIGH_BALLOT_RESULT(gl_SubgroupSize)) ? 0x8 : 0;\n";
143                         break;
144                 case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
145                         bdy << "  uint inclusiveOffset = gl_SubgroupInvocationID + 1;\n"
146                                 << "  tempResult |= inclusiveOffset == subgroupBallotInclusiveBitCount(allOnes) ? 0x1 : 0;\n"
147                                 << "  tempResult |= 0 == subgroupBallotInclusiveBitCount(allZeros) ? 0x2 : 0;\n"
148                                 << "  tempResult |= 0 < subgroupBallotInclusiveBitCount(subgroupBallot(true)) ? 0x4 : 0;\n"
149                                 << "  tempResult |= 0x8;\n"
150                                 << "  uvec4 inclusiveUndef = MAKE_HIGH_BALLOT_RESULT(inclusiveOffset);\n"
151                                 << "  bool undefTerritory = false;\n"
152                                 << "  for (uint i = 0; i <= 128; i++)\n"
153                                 << "  {\n"
154                                 << "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
155                                 << "    if (iUndef == inclusiveUndef)"
156                                 << "    {\n"
157                                 << "      undefTerritory = true;\n"
158                                 << "    }\n"
159                                 << "    uint inclusiveBitCount = subgroupBallotInclusiveBitCount(iUndef);\n"
160                                 << "    if (undefTerritory && (0 != inclusiveBitCount))\n"
161                                 << "    {\n"
162                                 << "      tempResult &= ~0x8;\n"
163                                 << "    }\n"
164                                 << "    else if (!undefTerritory && (0 == inclusiveBitCount))\n"
165                                 << "    {\n"
166                                 << "      tempResult &= ~0x8;\n"
167                                 << "    }\n"
168                                 << "  }\n";
169                         break;
170                 case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
171                         bdy << "  uint exclusiveOffset = gl_SubgroupInvocationID;\n"
172                                 << "  tempResult |= exclusiveOffset == subgroupBallotExclusiveBitCount(allOnes) ? 0x1 : 0;\n"
173                                 << "  tempResult |= 0 == subgroupBallotExclusiveBitCount(allZeros) ? 0x2 : 0;\n"
174                                 << "  tempResult |= 0x4;\n"
175                                 << "  tempResult |= 0x8;\n"
176                                 << "  uvec4 exclusiveUndef = MAKE_HIGH_BALLOT_RESULT(exclusiveOffset);\n"
177                                 << "  bool undefTerritory = false;\n"
178                                 << "  for (uint i = 0; i <= 128; i++)\n"
179                                 << "  {\n"
180                                 << "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
181                                 << "    if (iUndef == exclusiveUndef)"
182                                 << "    {\n"
183                                 << "      undefTerritory = true;\n"
184                                 << "    }\n"
185                                 << "    uint exclusiveBitCount = subgroupBallotExclusiveBitCount(iUndef);\n"
186                                 << "    if (undefTerritory && (0 != exclusiveBitCount))\n"
187                                 << "    {\n"
188                                 << "      tempResult &= ~0x4;\n"
189                                 << "    }\n"
190                                 << "    else if (!undefTerritory && (0 == exclusiveBitCount))\n"
191                                 << "    {\n"
192                                 << "      tempResult &= ~0x8;\n"
193                                 << "    }\n"
194                                 << "  }\n";
195                         break;
196                 case OPTYPE_BALLOT_FIND_LSB:
197                         bdy << "  tempResult |= 0 == subgroupBallotFindLSB(allOnes) ? 0x1 : 0;\n"
198                                 << "  if (subgroupElect())\n"
199                                 << "  {\n"
200                                 << "    tempResult |= 0x2;\n"
201                                 << "  }\n"
202                                 << "  else\n"
203                                 << "  {\n"
204                                 << "    tempResult |= 0 < subgroupBallotFindLSB(subgroupBallot(true)) ? 0x2 : 0;\n"
205                                 << "  }\n"
206                                 << "  tempResult |= gl_SubgroupSize > subgroupBallotFindLSB(subgroupBallot(true)) ? 0x4 : 0;\n"
207                                 << "  tempResult |= 0x8;\n"
208                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
209                                 << "  {\n"
210                                 << "    if (i != subgroupBallotFindLSB(MAKE_HIGH_BALLOT_RESULT(i)))\n"
211                                 << "    {\n"
212                                 << "      tempResult &= ~0x8;\n"
213                                 << "    }\n"
214                                 << "  }\n";
215                         break;
216                 case OPTYPE_BALLOT_FIND_MSB:
217                         bdy << "  tempResult |= (gl_SubgroupSize - 1) == subgroupBallotFindMSB(allOnes) ? 0x1 : 0;\n"
218                                 << "  if (subgroupElect())\n"
219                                 << "  {\n"
220                                 << "    tempResult |= 0x2;\n"
221                                 << "  }\n"
222                                 << "  else\n"
223                                 << "  {\n"
224                                 << "    tempResult |= 0 < subgroupBallotFindMSB(subgroupBallot(true)) ? 0x2 : 0;\n"
225                                 << "  }\n"
226                                 << "  tempResult |= gl_SubgroupSize > subgroupBallotFindMSB(subgroupBallot(true)) ? 0x4 : 0;\n"
227                                 << "  tempResult |= 0x8;\n"
228                                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
229                                 << "  {\n"
230                                 << "    if (i != subgroupBallotFindMSB(MAKE_SINGLE_BIT_BALLOT_RESULT(i)))\n"
231                                 << "    {\n"
232                                 << "      tempResult &= ~0x8;\n"
233                                 << "    }\n"
234                                 << "  }\n";
235                         break;
236         }
237    return bdy.str();
238 }
239
240 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
241 {
242         subgroups::setFragmentShaderFrameBuffer(programCollection);
243
244         if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
245                 subgroups::setVertexShaderFrameBuffer(programCollection);
246
247         std::string bdyStr = getBodySource(caseDef);
248
249         if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
250         {
251                 std::ostringstream                              vertex;
252                 vertex << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
253                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
254                         << "layout(location = 0) in highp vec4 in_position;\n"
255                         << "layout(location = 0) out float out_color;\n"
256                         << "\n"
257                         << "void main (void)\n"
258                         << "{\n"
259                         << bdyStr
260                         << "  out_color = float(tempResult);\n"
261                         << "  gl_Position = in_position;\n"
262                         << "  gl_PointSize = 1.0f;\n"
263                         << "}\n";
264                 programCollection.add("vert") << glu::VertexSource(vertex.str());
265         }
266         else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
267         {
268                 std::ostringstream geometry;
269
270                 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
271                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
272                         << "layout(points) in;\n"
273                         << "layout(points, max_vertices = 1) out;\n"
274                         << "layout(location = 0) out float out_color;\n"
275                         << "void main (void)\n"
276                         << "{\n"
277                         << bdyStr
278                         << "  out_color = float(tempResult);\n"
279                         << "  gl_Position = gl_in[0].gl_Position;\n"
280                         << "  EmitVertex();\n"
281                         << "  EndPrimitive();\n"
282                         << "}\n";
283
284                 programCollection.add("geometry") << glu::GeometrySource(geometry.str());
285         }
286         else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
287         {
288                 std::ostringstream controlSource;
289
290                 controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
291                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
292                         << "layout(vertices = 2) out;\n"
293                         << "layout(location = 0) out float out_color[];\n"
294                         << "\n"
295                         << "void main (void)\n"
296                         << "{\n"
297                         << "  if (gl_InvocationID == 0)\n"
298                         << "  {\n"
299                         << "    gl_TessLevelOuter[0] = 1.0f;\n"
300                         << "    gl_TessLevelOuter[1] = 1.0f;\n"
301                         << "  }\n"
302                         << bdyStr
303                         << "  out_color[gl_InvocationID ] = float(tempResult);\n"
304                         << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
305                         << "}\n";
306
307                 programCollection.add("tesc") << glu::TessellationControlSource(controlSource.str());
308                 subgroups::setTesEvalShaderFrameBuffer(programCollection);
309         }
310         else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
311         {
312                 std::ostringstream evaluationSource;
313                 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
314                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
315                         << "layout(isolines, equal_spacing, ccw ) in;\n"
316                         << "layout(location = 0) out float out_color;\n"
317                         << "void main (void)\n"
318                         << "{\n"
319                         << bdyStr
320                         << "  out_color  = float(tempResult);\n"
321                         << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
322                         << "}\n";
323
324                 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
325                 programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str());
326         }
327         else
328         {
329                 DE_FATAL("Unsupported shader stage");
330         }
331 }
332
333 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
334 {
335         std::string bdyStr = getBodySource(caseDef);
336
337         if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
338         {
339                 std::ostringstream src;
340
341                 src << "#version 450\n"
342                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
343                         << "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
344                         << "layout(binding = 0, std430) buffer Buffer0\n"
345                         << "{\n"
346                         << "  uint result[];\n"
347                         << "};\n"
348                         << "\n"
349                         << "void main (void)\n"
350                         << "{\n"
351                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
352                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
353                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
354                         "gl_GlobalInvocationID.x;\n"
355                         << bdyStr
356                         << "  result[offset] = tempResult;\n"
357                         << "}\n";
358
359                 programCollection.add("comp") << glu::ComputeSource(src.str());
360         }
361         else
362         {
363                 const string vertex =
364                         "#version 450\n"
365                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
366                         "layout(binding = 0, std430) buffer Buffer0\n"
367                         "{\n"
368                         "  uint result[];\n"
369                         "} b0;\n"
370                         "\n"
371                         "void main (void)\n"
372                         "{\n"
373                         + bdyStr +
374                         "  b0.result[gl_VertexID] = tempResult;\n"
375                         "  float pixelSize = 2.0f/1024.0f;\n"
376                         "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
377                         "  gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
378                         "  gl_PointSize = 1.0f;\n"
379                         "}\n";
380
381                 const string tesc =
382                         "#version 450\n"
383                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
384                         "layout(vertices=1) out;\n"
385                         "layout(binding = 1, std430) buffer Buffer1\n"
386                         "{\n"
387                         "  uint result[];\n"
388                         "} b1;\n"
389                         "\n"
390                         "void main (void)\n"
391                         "{\n"
392                         + bdyStr +
393                         "  b1.result[gl_PrimitiveID] = tempResult;\n"
394                         "  if (gl_InvocationID == 0)\n"
395                         "  {\n"
396                         "    gl_TessLevelOuter[0] = 1.0f;\n"
397                         "    gl_TessLevelOuter[1] = 1.0f;\n"
398                         "  }\n"
399                         "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
400                         "}\n";
401
402                 const string tese =
403                         "#version 450\n"
404                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
405                         "layout(isolines) in;\n"
406                         "layout(binding = 2, std430) buffer Buffer2\n"
407                         "{\n"
408                         "  uint result[];\n"
409                         "} b2;\n"
410                         "\n"
411                         "void main (void)\n"
412                         "{\n"
413                         + bdyStr +
414                         "  b2.result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
415                         "  float pixelSize = 2.0f/1024.0f;\n"
416                         "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
417                         "}\n";
418
419                 const string geometry =
420                         "#version 450\n"
421                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
422                         "layout(${TOPOLOGY}) in;\n"
423                         "layout(points, max_vertices = 1) out;\n"
424                         "layout(binding = 3, std430) buffer Buffer3\n"
425                         "{\n"
426                         "  uint result[];\n"
427                         "} b3;\n"
428                         "\n"
429                         "void main (void)\n"
430                         "{\n"
431                         + bdyStr +
432                         "  b3.result[gl_PrimitiveIDIn] = tempResult;\n"
433                         "  gl_Position = gl_in[0].gl_Position;\n"
434                         "  EmitVertex();\n"
435                         "  EndPrimitive();\n"
436                         "}\n";
437
438                 const string fragment =
439                         "#version 450\n"
440                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
441                         "layout(location = 0) out uint result;\n"
442                         "void main (void)\n"
443                         "{\n"
444                         + bdyStr +
445                         "  result = tempResult;\n"
446                         "}\n";
447
448                 subgroups::addNoSubgroupShader(programCollection);
449
450                 programCollection.add("vert") << glu::VertexSource(vertex);
451                 programCollection.add("tesc") << glu::TessellationControlSource(tesc);
452                 programCollection.add("tese") << glu::TessellationEvaluationSource(tese);
453                 subgroups::addGeometryShadersFromTemplate(geometry, programCollection);
454                 programCollection.add("fragment") << glu::FragmentSource(fragment);
455         }
456 }
457
458 void supportedCheck (Context& context, CaseDefinition caseDef)
459 {
460         DE_UNREF(caseDef);
461         if (!subgroups::isSubgroupSupported(context))
462                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
463
464         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, SUBGROUP_FEATURE_BALLOT_BIT))
465         {
466                 TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
467         }
468 }
469
470 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
471 {
472         if (!subgroups::areSubgroupOperationsSupportedForStage(
473                         context, caseDef.shaderStage))
474         {
475                 if (subgroups::areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
476                 {
477                         return tcu::TestStatus::fail(
478                                            "Shader stage " +
479                                            subgroups::getShaderStageName(caseDef.shaderStage) +
480                                            " is required to support subgroup operations!");
481                 }
482                 else
483                 {
484                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
485                 }
486         }
487
488         if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
489                 return subgroups::makeVertexFrameBufferTest(context, FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
490         else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
491                 return subgroups::makeGeometryFrameBufferTest(context, FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
492         else if ((SHADER_STAGE_TESS_CONTROL_BIT | SHADER_STAGE_TESS_EVALUATION_BIT) & caseDef.shaderStage)
493                 return subgroups::makeTessellationEvaluationFrameBufferTest(context, FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
494         else
495                 TCU_THROW(InternalError, "Unhandled shader stage");
496 }
497
498 tcu::TestStatus test (Context& context, const CaseDefinition caseDef)
499 {
500         if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
501         {
502                 if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
503                 {
504                         return tcu::TestStatus::fail(
505                                            "Shader stage " +
506                                 subgroups::getShaderStageName(caseDef.shaderStage) +
507                                 " is required to support subgroup operations!");
508                 }
509                 return subgroups::makeComputeTest(context, FORMAT_R32_UINT, DE_NULL, 0, checkComputeStage);
510         }
511         else
512         {
513                 int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
514
515                 ShaderStageFlags stages = (ShaderStageFlags)(caseDef.shaderStage & supportedStages);
516
517                 if ( SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
518                 {
519                         if ( (stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
520                                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
521                         else
522                                 stages = SHADER_STAGE_FRAGMENT_BIT;
523                 }
524
525                 if ((ShaderStageFlags)0u == stages)
526                         TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
527
528                 return subgroups::allStages(context, FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages, stages);
529         }
530         return tcu::TestStatus::pass("OK");
531 }
532 }
533
534 deqp::TestCaseGroup* createSubgroupsBallotOtherTests(deqp::Context& testCtx)
535 {
536         de::MovePtr<deqp::TestCaseGroup> graphicGroup(new deqp::TestCaseGroup(
537                 testCtx, "graphics", "Subgroup ballot other category tests: graphics"));
538         de::MovePtr<deqp::TestCaseGroup> computeGroup(new deqp::TestCaseGroup(
539                 testCtx, "compute", "Subgroup ballot other category tests: compute"));
540         de::MovePtr<deqp::TestCaseGroup> framebufferGroup(new deqp::TestCaseGroup(
541                 testCtx, "framebuffer", "Subgroup ballot other category tests: framebuffer"));
542
543         const ShaderStageFlags stages[] =
544         {
545                 SHADER_STAGE_VERTEX_BIT,
546                 SHADER_STAGE_TESS_EVALUATION_BIT,
547                 SHADER_STAGE_TESS_CONTROL_BIT,
548                 SHADER_STAGE_GEOMETRY_BIT,
549         };
550
551         for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
552         {
553                 const string    op              = de::toLower(getOpTypeName(opTypeIndex));
554                 {
555                         const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_COMPUTE_BIT};
556                         SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(computeGroup.get(), op, "", supportedCheck, initPrograms, test, caseDef);
557                 }
558
559                 {
560                         const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_ALL_GRAPHICS};
561                         SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(graphicGroup.get(), op, "", supportedCheck, initPrograms, test, caseDef);
562                 }
563
564                 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
565                 {
566                         const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex]};
567                         SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(framebufferGroup.get(), op + "_" + getShaderStageName(caseDef.shaderStage), "", supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
568                 }
569         }
570
571         de::MovePtr<deqp::TestCaseGroup> group(new deqp::TestCaseGroup(
572                 testCtx, "ballot_other", "Subgroup ballot other category tests"));
573
574         group->addChild(graphicGroup.release());
575         group->addChild(computeGroup.release());
576         group->addChild(framebufferGroup.release());
577
578         return group.release();
579 }
580
581 } // subgroups
582 } // glc