porting changes for OpenGL Subgroup tests
[platform/upstream/VK-GL-CTS.git] / external / openglcts / modules / common / subgroups / glcSubgroupsBuiltinVarTests.cpp
1 /*------------------------------------------------------------------------
2  * OpenGL Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017-2019 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  * Copyright (c) 2019 NVIDIA Corporation.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25
26 #include "glcSubgroupsBuiltinVarTests.hpp"
27 #include "glcSubgroupsTestsUtils.hpp"
28
29 #include <string>
30 #include <vector>
31
32 using namespace tcu;
33 using namespace std;
34
35 namespace glc
36 {
37 namespace subgroups
38 {
39
40 bool checkVertexPipelineStagesSubgroupSize(std::vector<const void*> datas,
41                 deUint32 width, deUint32 subgroupSize)
42 {
43         const deUint32* data =
44                 reinterpret_cast<const deUint32*>(datas[0]);
45         for (deUint32 x = 0; x < width; ++x)
46         {
47                 deUint32 val = data[x * 4];
48
49                 if (subgroupSize != val)
50                         return false;
51         }
52
53         return true;
54 }
55
56 bool checkVertexPipelineStagesSubgroupInvocationID(std::vector<const void*> datas,
57                 deUint32 width, deUint32 subgroupSize)
58 {
59         const deUint32* data =
60                 reinterpret_cast<const deUint32*>(datas[0]);
61         vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
62
63         for (deUint32 x = 0; x < width; ++x)
64         {
65                 deUint32 subgroupInvocationID = data[(x * 4) + 1];
66
67                 if (subgroupInvocationID >= subgroupSize)
68                         return false;
69                 subgroupInvocationHits[subgroupInvocationID]++;
70         }
71
72         const deUint32 totalSize = width;
73
74         deUint32 totalInvocationsRun = 0;
75         for (deUint32 i = 0; i < subgroupSize; ++i)
76         {
77                 totalInvocationsRun += subgroupInvocationHits[i];
78         }
79
80         if (totalInvocationsRun != totalSize)
81                 return false;
82
83         return true;
84 }
85
86 static bool checkComputeSubgroupSize(std::vector<const void*> datas,
87                                                                          const deUint32 numWorkgroups[3], const deUint32 localSize[3],
88                                                                          deUint32 subgroupSize)
89 {
90         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
91
92         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
93         {
94                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
95                 {
96                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
97                         {
98                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
99                                 {
100                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
101                                         {
102                                                 for (deUint32 lZ = 0; lZ < localSize[2];
103                                                                 ++lZ)
104                                                 {
105                                                         const deUint32 globalInvocationX =
106                                                                 nX * localSize[0] + lX;
107                                                         const deUint32 globalInvocationY =
108                                                                 nY * localSize[1] + lY;
109                                                         const deUint32 globalInvocationZ =
110                                                                 nZ * localSize[2] + lZ;
111
112                                                         const deUint32 globalSizeX =
113                                                                 numWorkgroups[0] * localSize[0];
114                                                         const deUint32 globalSizeY =
115                                                                 numWorkgroups[1] * localSize[1];
116
117                                                         const deUint32 offset =
118                                                                 globalSizeX *
119                                                                 ((globalSizeY *
120                                                                   globalInvocationZ) +
121                                                                  globalInvocationY) +
122                                                                 globalInvocationX;
123
124                                                         if (subgroupSize != data[offset * 4])
125                                                                 return false;
126                                                 }
127                                         }
128                                 }
129                         }
130                 }
131         }
132
133         return true;
134 }
135
136 static bool checkComputeSubgroupInvocationID(std::vector<const void*> datas,
137                 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
138                 deUint32 subgroupSize)
139 {
140         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
141
142         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
143         {
144                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
145                 {
146                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
147                         {
148                                 const deUint32 totalLocalSize =
149                                         localSize[0] * localSize[1] * localSize[2];
150                                 vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
151
152                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
153                                 {
154                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
155                                         {
156                                                 for (deUint32 lZ = 0; lZ < localSize[2];
157                                                                 ++lZ)
158                                                 {
159                                                         const deUint32 globalInvocationX =
160                                                                 nX * localSize[0] + lX;
161                                                         const deUint32 globalInvocationY =
162                                                                 nY * localSize[1] + lY;
163                                                         const deUint32 globalInvocationZ =
164                                                                 nZ * localSize[2] + lZ;
165
166                                                         const deUint32 globalSizeX =
167                                                                 numWorkgroups[0] * localSize[0];
168                                                         const deUint32 globalSizeY =
169                                                                 numWorkgroups[1] * localSize[1];
170
171                                                         const deUint32 offset =
172                                                                 globalSizeX *
173                                                                 ((globalSizeY *
174                                                                   globalInvocationZ) +
175                                                                  globalInvocationY) +
176                                                                 globalInvocationX;
177
178                                                         deUint32 subgroupInvocationID = data[(offset * 4) + 1];
179
180                                                         if (subgroupInvocationID >= subgroupSize)
181                                                                 return false;
182
183                                                         subgroupInvocationHits[subgroupInvocationID]++;
184                                                 }
185                                         }
186                                 }
187
188                                 deUint32 totalInvocationsRun = 0;
189                                 for (deUint32 i = 0; i < subgroupSize; ++i)
190                                 {
191                                         totalInvocationsRun += subgroupInvocationHits[i];
192                                 }
193
194                                 if (totalInvocationsRun != totalLocalSize)
195                                         return false;
196                         }
197                 }
198         }
199
200         return true;
201 }
202
203 static bool checkComputeNumSubgroups    (std::vector<const void*>       datas,
204                                                                                 const deUint32                          numWorkgroups[3],
205                                                                                 const deUint32                          localSize[3],
206                                                                                 deUint32)
207 {
208         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
209
210         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
211         {
212                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
213                 {
214                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
215                         {
216                                 const deUint32 totalLocalSize =
217                                         localSize[0] * localSize[1] * localSize[2];
218
219                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
220                                 {
221                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
222                                         {
223                                                 for (deUint32 lZ = 0; lZ < localSize[2];
224                                                                 ++lZ)
225                                                 {
226                                                         const deUint32 globalInvocationX =
227                                                                 nX * localSize[0] + lX;
228                                                         const deUint32 globalInvocationY =
229                                                                 nY * localSize[1] + lY;
230                                                         const deUint32 globalInvocationZ =
231                                                                 nZ * localSize[2] + lZ;
232
233                                                         const deUint32 globalSizeX =
234                                                                 numWorkgroups[0] * localSize[0];
235                                                         const deUint32 globalSizeY =
236                                                                 numWorkgroups[1] * localSize[1];
237
238                                                         const deUint32 offset =
239                                                                 globalSizeX *
240                                                                 ((globalSizeY *
241                                                                   globalInvocationZ) +
242                                                                  globalInvocationY) +
243                                                                 globalInvocationX;
244
245                                                         deUint32 numSubgroups = data[(offset * 4) + 2];
246
247                                                         if (numSubgroups > totalLocalSize)
248                                                                 return false;
249                                                 }
250                                         }
251                                 }
252                         }
253                 }
254         }
255
256         return true;
257 }
258
259 static bool checkComputeSubgroupID      (std::vector<const void*>       datas,
260                                                                         const deUint32                          numWorkgroups[3],
261                                                                         const deUint32                          localSize[3],
262                                                                         deUint32)
263 {
264         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
265
266         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
267         {
268                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
269                 {
270                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
271                         {
272                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
273                                 {
274                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
275                                         {
276                                                 for (deUint32 lZ = 0; lZ < localSize[2];
277                                                                 ++lZ)
278                                                 {
279                                                         const deUint32 globalInvocationX =
280                                                                 nX * localSize[0] + lX;
281                                                         const deUint32 globalInvocationY =
282                                                                 nY * localSize[1] + lY;
283                                                         const deUint32 globalInvocationZ =
284                                                                 nZ * localSize[2] + lZ;
285
286                                                         const deUint32 globalSizeX =
287                                                                 numWorkgroups[0] * localSize[0];
288                                                         const deUint32 globalSizeY =
289                                                                 numWorkgroups[1] * localSize[1];
290
291                                                         const deUint32 offset =
292                                                                 globalSizeX *
293                                                                 ((globalSizeY *
294                                                                   globalInvocationZ) +
295                                                                  globalInvocationY) +
296                                                                 globalInvocationX;
297
298                                                         deUint32 numSubgroups = data[(offset * 4) + 2];
299                                                         deUint32 subgroupID = data[(offset * 4) + 3];
300
301                                                         if (subgroupID >= numSubgroups)
302                                                                 return false;
303                                                 }
304                                         }
305                                 }
306                         }
307                 }
308         }
309
310         return true;
311 }
312
313 namespace
314 {
315 struct CaseDefinition
316 {
317         std::string varName;
318         ShaderStageFlags shaderStage;
319 };
320 }
321
322 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
323 {
324         {
325                 const string fragmentGLSL =
326                     "#version 450\n"
327                         "layout(location = 0) in vec4 in_color;\n"
328                         "layout(location = 0) out uvec4 out_color;\n"
329                         "void main()\n"
330                         "{\n"
331                          "      out_color = uvec4(in_color);\n"
332                          "}\n";
333                 programCollection.add("fragment") << glu::FragmentSource(fragmentGLSL);
334         }
335
336         if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
337                 subgroups::setVertexShaderFrameBuffer(programCollection);
338
339         if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
340         {
341                 const string vertexGLSL =
342                         "#version 450\n"
343                         "#extension GL_KHR_shader_subgroup_basic: enable\n"
344                         "layout(location = 0) out vec4 out_color;\n"
345                         "layout(location = 0) in highp vec4 in_position;\n"
346                         "\n"
347                         "void main (void)\n"
348                         "{\n"
349                         "  out_color = vec4(gl_SubgroupSize, gl_SubgroupInvocationID, 1.0f, 1.0f);\n"
350                         "  gl_Position = in_position;\n"
351                         "  gl_PointSize = 1.0f;\n"
352                         "}\n";
353                 programCollection.add("vert") << glu::VertexSource(vertexGLSL);
354         }
355         else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
356         {
357                 const string controlSourceGLSL =
358                         "#version 450\n"
359                         "#extension GL_EXT_tessellation_shader : require\n"
360                         "layout(vertices = 2) out;\n"
361                         "layout(location = 0) out vec4 out_color[];\n"
362                         "void main (void)\n"
363                         "{\n"
364                         "  if (gl_InvocationID == 0)\n"
365                         "  {\n"
366                         "    gl_TessLevelOuter[0] = 1.0f;\n"
367                         "    gl_TessLevelOuter[1] = 1.0f;\n"
368                         "  }\n"
369                         "  out_color[gl_InvocationID] = vec4(0.0f);\n"
370                         "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
371                         "}\n";
372                 programCollection.add("tesc") << glu::TessellationControlSource(controlSourceGLSL);
373
374                 const string evaluationSourceGLSL =
375                         "#version 450\n"
376                         "#extension GL_KHR_shader_subgroup_basic: enable\n"
377                         "#extension GL_EXT_tessellation_shader : require\n"
378                         "layout(isolines, equal_spacing, ccw ) in;\n"
379                         "layout(location = 0) in vec4 in_color[];\n"
380                         "layout(location = 0) out vec4 out_color;\n"
381                         "\n"
382                         "void main (void)\n"
383                         "{\n"
384                         "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
385                         "  out_color = vec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0.0f, 0.0f);\n"
386                         "}\n";
387                 programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSourceGLSL);
388         }
389         else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
390         {
391                 const string controlSourceGLSL =
392                         "#version 450\n"
393                         "#extension GL_EXT_tessellation_shader : require\n"
394                         "#extension GL_KHR_shader_subgroup_basic: enable\n"
395                         "layout(vertices = 2) out;\n"
396                         "layout(location = 0) out vec4 out_color[];\n"
397                         "void main (void)\n"
398                         "{\n"
399                         "  if (gl_InvocationID == 0)\n"
400                         "  {\n"
401                         "    gl_TessLevelOuter[0] = 1.0f;\n"
402                         "    gl_TessLevelOuter[1] = 1.0f;\n"
403                         "  }\n"
404                         "  out_color[gl_InvocationID] = vec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
405                         "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
406                         "}\n";
407                 programCollection.add("tesc") << glu::TessellationControlSource(controlSourceGLSL);
408
409                 const string  evaluationSourceGLSL =
410                         "#version 450\n"
411                         "#extension GL_KHR_shader_subgroup_basic: enable\n"
412                         "#extension GL_EXT_tessellation_shader : require\n"
413                         "layout(isolines, equal_spacing, ccw ) in;\n"
414                         "layout(location = 0) in vec4 in_color[];\n"
415                         "layout(location = 0) out vec4 out_color;\n"
416                         "\n"
417                         "void main (void)\n"
418                         "{\n"
419                         "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
420                         "  out_color = in_color[0];\n"
421                         "}\n";
422                 programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSourceGLSL);
423         }
424         else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
425         {
426                 const string geometryGLSL =
427                         "#version 450\n"
428                         "#extension GL_KHR_shader_subgroup_basic: enable\n"
429                         "layout(points) in;\n"
430                         "layout(points, max_vertices = 1) out;\n"
431                         "layout(location = 0) out vec4 out_color;\n"
432                         "void main (void)\n"
433                         "{\n"
434                         "  out_color = vec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
435                         "  gl_Position = gl_in[0].gl_Position;\n"
436                         "  EmitVertex();\n"
437                         "  EndPrimitive();\n"
438                         "}\n";
439                 programCollection.add("geometry") << glu::GeometrySource(geometryGLSL);
440         }
441         else
442         {
443                 DE_FATAL("Unsupported shader stage");
444         }
445 }
446
447 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
448 {
449         if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
450         {
451                 std::ostringstream src;
452
453                 src << "#version 450\n"
454                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
455                         << "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
456                         << "layout(binding = 0, std430) buffer Output\n"
457                         << "{\n"
458                         << "  uvec4 result[];\n"
459                         << "};\n"
460                         << "\n"
461                         << "void main (void)\n"
462                         << "{\n"
463                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
464                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
465                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
466                         "gl_GlobalInvocationID.x;\n"
467                         << "  result[offset] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, gl_NumSubgroups, gl_SubgroupID);\n"
468                         << "}\n";
469
470                 programCollection.add("comp") << glu::ComputeSource(src.str());
471         }
472         else
473         {
474                 {
475                         const string vertexGLSL =
476                                 "#version 450\n"
477                                 "#extension GL_KHR_shader_subgroup_basic: enable\n"
478                                 "layout(binding = 0, std430) buffer Output0\n"
479                                 "{\n"
480                                 "  uvec4 result[];\n"
481                                 "} b0;\n"
482                                 "\n"
483                                 "void main (void)\n"
484                                 "{\n"
485                                 "  b0.result[gl_VertexID] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
486                                 "  float pixelSize = 2.0f/1024.0f;\n"
487                                 "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
488                                 "  gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
489                                 "  gl_PointSize = 1.0f;\n"
490                                 "}\n";
491                         programCollection.add("vert") << glu::VertexSource(vertexGLSL);
492                 }
493
494                 {
495                         const string tescGLSL =
496                                 "#version 450\n"
497                                 "#extension GL_KHR_shader_subgroup_basic: enable\n"
498                                 "layout(vertices=1) out;\n"
499                                 "layout(binding = 1, std430) buffer Output1\n"
500                                 "{\n"
501                                 "  uvec4 result[];\n"
502                                 "} b1;\n"
503                                 "\n"
504                                 "void main (void)\n"
505                                 "{\n"
506                                 "  b1.result[gl_PrimitiveID] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
507                                 "  if (gl_InvocationID == 0)\n"
508                                 "  {\n"
509                                 "    gl_TessLevelOuter[0] = 1.0f;\n"
510                                 "    gl_TessLevelOuter[1] = 1.0f;\n"
511                                 "  }\n"
512                                 "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
513                                 "}\n";
514                         programCollection.add("tesc") << glu::TessellationControlSource(tescGLSL);
515                 }
516
517                 {
518                         const string teseGLSL =
519                                 "#version 450\n"
520                                 "#extension GL_KHR_shader_subgroup_basic: enable\n"
521                                 "layout(isolines) in;\n"
522                                 "layout(binding = 2, std430) buffer Output2\n"
523                                 "{\n"
524                                 "  uvec4 result[];\n"
525                                 "} b2;\n"
526                                 "\n"
527                                 "void main (void)\n"
528                                 "{\n"
529                                 "  b2.result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
530                                 "  float pixelSize = 2.0f/1024.0f;\n"
531                                 "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
532                                 "}\n";
533                         programCollection.add("tese") << glu::TessellationEvaluationSource(teseGLSL);
534                 }
535
536                 {
537                         const string geometryGLSL =
538                                 "#version 450\n"
539                                 "#extension GL_KHR_shader_subgroup_basic: enable\n"
540                                 "layout(${TOPOLOGY}) in;\n"
541                                 "layout(points, max_vertices = 1) out;\n"
542                                 "layout(binding = 3, std430) buffer Output3\n"
543                                 "{\n"
544                                 "  uvec4 result[];\n"
545                                 "} b3;\n"
546                                 "\n"
547                                 "void main (void)\n"
548                                 "{\n"
549                                 "  b3.result[gl_PrimitiveIDIn] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
550                                 "  gl_Position = gl_in[0].gl_Position;\n"
551                                 "  EmitVertex();\n"
552                                 "  EndPrimitive();\n"
553                                 "}\n";
554                         addGeometryShadersFromTemplate(geometryGLSL, programCollection);
555                 }
556
557                 {
558                         const string fragmentGLSL =
559                                 "#version 450\n"
560                                 "#extension GL_KHR_shader_subgroup_basic: enable\n"
561                                 "layout(location = 0) out uvec4 data;\n"
562                                 "void main (void)\n"
563                                 "{\n"
564                                 "  data = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
565                                 "}\n";
566                         programCollection.add("fragment") << glu::FragmentSource(fragmentGLSL);
567                 }
568
569                 subgroups::addNoSubgroupShader(programCollection);
570         }
571 }
572
573 void supportedCheck (Context& context, CaseDefinition caseDef)
574 {
575         DE_UNREF(caseDef);
576         if (!subgroups::isSubgroupSupported(context))
577                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
578 }
579
580 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
581 {
582         if (!areSubgroupOperationsSupportedForStage(
583                                 context, caseDef.shaderStage))
584         {
585                 if (areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
586                 {
587                         return tcu::TestStatus::fail(
588                                            "Shader stage " + getShaderStageName(caseDef.shaderStage) +
589                                            " is required to support subgroup operations!");
590                 }
591                 else
592                 {
593                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
594                 }
595         }
596
597         if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
598         {
599                 if ("gl_SubgroupSize" == caseDef.varName)
600                 {
601                         return makeVertexFrameBufferTest(
602                                            context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
603                 }
604                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
605                 {
606                         return makeVertexFrameBufferTest(
607                                            context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
608                 }
609                 else
610                 {
611                         return tcu::TestStatus::fail(
612                                            caseDef.varName + " failed (unhandled error checking case " +
613                                            caseDef.varName + ")!");
614                 }
615         }
616         else if ((SHADER_STAGE_TESS_EVALUATION_BIT | SHADER_STAGE_TESS_CONTROL_BIT) & caseDef.shaderStage )
617         {
618                 if ("gl_SubgroupSize" == caseDef.varName)
619                 {
620                         return makeTessellationEvaluationFrameBufferTest(
621                                         context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
622                 }
623                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
624                 {
625                         return makeTessellationEvaluationFrameBufferTest(
626                                         context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
627                 }
628                 else
629                 {
630                         return tcu::TestStatus::fail(
631                                         caseDef.varName + " failed (unhandled error checking case " +
632                                         caseDef.varName + ")!");
633                 }
634         }
635         else if (SHADER_STAGE_GEOMETRY_BIT & caseDef.shaderStage )
636         {
637                 if ("gl_SubgroupSize" == caseDef.varName)
638                 {
639                         return makeGeometryFrameBufferTest(
640                                         context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
641                 }
642                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
643                 {
644                         return makeGeometryFrameBufferTest(
645                                         context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
646                 }
647                 else
648                 {
649                         return tcu::TestStatus::fail(
650                                         caseDef.varName + " failed (unhandled error checking case " +
651                                         caseDef.varName + ")!");
652                 }
653         }
654         else
655         {
656                 TCU_THROW(InternalError, "Unhandled shader stage");
657         }
658 }
659
660
661 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
662 {
663         if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
664         {
665                 if (!areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
666                 {
667                         return tcu::TestStatus::fail(
668                                            "Shader stage " + getShaderStageName(caseDef.shaderStage) +
669                                            " is required to support subgroup operations!");
670                 }
671
672                 if ("gl_SubgroupSize" == caseDef.varName)
673                 {
674                         return makeComputeTest(context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupSize);
675                 }
676                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
677                 {
678                         return makeComputeTest(context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupInvocationID);
679                 }
680                 else if ("gl_NumSubgroups" == caseDef.varName)
681                 {
682                         return makeComputeTest(context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeNumSubgroups);
683                 }
684                 else if ("gl_SubgroupID" == caseDef.varName)
685                 {
686                         return makeComputeTest(context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupID);
687                 }
688                 else
689                 {
690                         return tcu::TestStatus::fail(
691                                         caseDef.varName + " failed (unhandled error checking case " +
692                                         caseDef.varName + ")!");
693                 }
694         }
695         else
696         {
697                 int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
698
699                 subgroups::ShaderStageFlags stages = (subgroups::ShaderStageFlags)(caseDef.shaderStage & supportedStages);
700
701                 if (SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
702                 {
703                         if ( (stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
704                                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
705                         else
706                                 stages = SHADER_STAGE_FRAGMENT_BIT;
707                 }
708
709                 if ((ShaderStageFlags)0u == stages)
710                         TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
711
712                 if ("gl_SubgroupSize" == caseDef.varName)
713                 {
714                         return subgroups::allStages(context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize, stages);
715                 }
716                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
717                 {
718                         return subgroups::allStages(context, FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID, stages);
719                 }
720                 else
721                 {
722                         return tcu::TestStatus::fail(
723                                            caseDef.varName + " failed (unhandled error checking case " +
724                                            caseDef.varName + ")!");
725                 }
726         }
727 }
728
729 deqp::TestCaseGroup* createSubgroupsBuiltinVarTests(deqp::Context& testCtx)
730 {
731         de::MovePtr<deqp::TestCaseGroup> graphicGroup(new deqp::TestCaseGroup(
732                 testCtx, "graphics", "Subgroup builtin variable tests: graphics"));
733         de::MovePtr<deqp::TestCaseGroup> computeGroup(new deqp::TestCaseGroup(
734                 testCtx, "compute", "Subgroup builtin variable tests: compute"));
735         de::MovePtr<deqp::TestCaseGroup> framebufferGroup(new deqp::TestCaseGroup(
736                 testCtx, "framebuffer", "Subgroup builtin variable tests: framebuffer"));
737
738         const char* const all_stages_vars[] =
739         {
740                 "SubgroupSize",
741                 "SubgroupInvocationID"
742         };
743
744         const char* const compute_only_vars[] =
745         {
746                 "NumSubgroups",
747                 "SubgroupID"
748         };
749
750         const ShaderStageFlags stages[] =
751         {
752                 SHADER_STAGE_VERTEX_BIT,
753                 SHADER_STAGE_TESS_EVALUATION_BIT,
754                 SHADER_STAGE_TESS_CONTROL_BIT,
755                 SHADER_STAGE_GEOMETRY_BIT,
756         };
757
758         for (int a = 0; a < DE_LENGTH_OF_ARRAY(all_stages_vars); ++a)
759         {
760                 const std::string var = all_stages_vars[a];
761                 const std::string varLower = de::toLower(var);
762
763                 {
764                         const CaseDefinition caseDef = { "gl_" + var, SHADER_STAGE_ALL_GRAPHICS};
765
766                         SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(graphicGroup.get(),
767                                                                                 varLower, "",
768                                                                                 supportedCheck, initPrograms, test, caseDef);
769                 }
770
771                 {
772                         const CaseDefinition caseDef = {"gl_" + var, SHADER_STAGE_COMPUTE_BIT};
773                         SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(computeGroup.get(),
774                                                 varLower + "_" + getShaderStageName(caseDef.shaderStage), "",
775                                                 supportedCheck, initPrograms, test, caseDef);
776                 }
777
778                 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
779                 {
780                         const CaseDefinition caseDef = {"gl_" + var, stages[stageIndex]};
781                         SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(framebufferGroup.get(),
782                                                 varLower + "_" + getShaderStageName(caseDef.shaderStage), "",
783                                                 supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
784                 }
785         }
786
787         for (int a = 0; a < DE_LENGTH_OF_ARRAY(compute_only_vars); ++a)
788         {
789                 const std::string var = compute_only_vars[a];
790
791                 const CaseDefinition caseDef = {"gl_" + var, SHADER_STAGE_COMPUTE_BIT};
792
793                 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(computeGroup.get(), de::toLower(var), "",
794                                                                         supportedCheck, initPrograms, test, caseDef);
795         }
796
797         de::MovePtr<deqp::TestCaseGroup> group(new deqp::TestCaseGroup(
798                 testCtx, "builtin_var", "Subgroup builtin variable tests"));
799
800         group->addChild(graphicGroup.release());
801         group->addChild(computeGroup.release());
802         group->addChild(framebufferGroup.release());
803
804         return group.release();
805 }
806
807 } // subgroups
808 } // glc