e5aaf8a286ba549ea345947d903d30fabec22cec
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / subgroups / vktSubgroupsBuiltinVarTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsBuiltinVarTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34
35 namespace vkt
36 {
37 namespace subgroups
38 {
39
40 bool checkVertexPipelineStagesSubgroupSize(std::vector<const void*> datas,
41                 deUint32 width, deUint32 subgroupSize)
42 {
43         const deUint32* data =
44                 reinterpret_cast<const deUint32*>(datas[0]);
45         for (deUint32 x = 0; x < width; ++x)
46         {
47                 deUint32 val = data[x * 4];
48
49                 if (subgroupSize != val)
50                 {
51                         return false;
52                 }
53         }
54
55         return true;
56 }
57
58 bool checkVertexPipelineStagesSubgroupInvocationID(std::vector<const void*> datas,
59                 deUint32 width, deUint32 subgroupSize)
60 {
61         const deUint32* data =
62                 reinterpret_cast<const deUint32*>(datas[0]);
63         vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
64
65         for (deUint32 x = 0; x < width; ++x)
66         {
67                 deUint32 subgroupInvocationID = data[(x * 4) + 1];
68
69                 if (subgroupInvocationID >= subgroupSize)
70                 {
71                         return false;
72                 }
73
74                 subgroupInvocationHits[subgroupInvocationID]++;
75         }
76
77         const deUint32 totalSize = width;
78
79         deUint32 totalInvocationsRun = 0;
80         for (deUint32 i = 0; i < subgroupSize; ++i)
81         {
82                 totalInvocationsRun += subgroupInvocationHits[i];
83         }
84
85         if (totalInvocationsRun != totalSize)
86         {
87                 return false;
88         }
89
90         return true;
91 }
92
93 static bool checkFragmentSubgroupSize(std::vector<const void*> datas,
94                                                                           deUint32 width, deUint32 height, deUint32 subgroupSize)
95 {
96         const deUint32* data =
97                 reinterpret_cast<const deUint32*>(datas[0]);
98         for (deUint32 x = 0; x < width; ++x)
99         {
100                 for (deUint32 y = 0; y < height; ++y)
101                 {
102                         deUint32 val = data[(x * height + y) * 4];
103
104                         if (subgroupSize != val)
105                         {
106                                 return false;
107                         }
108                 }
109         }
110
111         return true;
112 }
113
114 static bool checkFragmentSubgroupInvocationID(
115         std::vector<const void*> datas, deUint32 width, deUint32 height,
116         deUint32 subgroupSize)
117 {
118         const deUint32* data =
119                 reinterpret_cast<const deUint32*>(datas[0]);
120         vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
121
122         for (deUint32 x = 0; x < width; ++x)
123         {
124                 for (deUint32 y = 0; y < height; ++y)
125                 {
126                         deUint32 subgroupInvocationID = data[((x * height + y) * 4) + 1];
127
128                         if (subgroupInvocationID >= subgroupSize)
129                         {
130                                 return false;
131                         }
132
133                         subgroupInvocationHits[subgroupInvocationID]++;
134                 }
135         }
136
137         const deUint32 totalSize = width * height;
138
139         deUint32 totalInvocationsRun = 0;
140         for (deUint32 i = 0; i < subgroupSize; ++i)
141         {
142                 totalInvocationsRun += subgroupInvocationHits[i];
143         }
144
145         if (totalInvocationsRun != totalSize)
146         {
147                 return false;
148         }
149
150         return true;
151 }
152
153 static bool checkComputeSubgroupSize(std::vector<const void*> datas,
154                                                                          const deUint32 numWorkgroups[3], const deUint32 localSize[3],
155                                                                          deUint32 subgroupSize)
156 {
157         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
158
159         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
160         {
161                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
162                 {
163                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
164                         {
165                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
166                                 {
167                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
168                                         {
169                                                 for (deUint32 lZ = 0; lZ < localSize[2];
170                                                                 ++lZ)
171                                                 {
172                                                         const deUint32 globalInvocationX =
173                                                                 nX * localSize[0] + lX;
174                                                         const deUint32 globalInvocationY =
175                                                                 nY * localSize[1] + lY;
176                                                         const deUint32 globalInvocationZ =
177                                                                 nZ * localSize[2] + lZ;
178
179                                                         const deUint32 globalSizeX =
180                                                                 numWorkgroups[0] * localSize[0];
181                                                         const deUint32 globalSizeY =
182                                                                 numWorkgroups[1] * localSize[1];
183
184                                                         const deUint32 offset =
185                                                                 globalSizeX *
186                                                                 ((globalSizeY *
187                                                                   globalInvocationZ) +
188                                                                  globalInvocationY) +
189                                                                 globalInvocationX;
190
191                                                         if (subgroupSize != data[offset * 4])
192                                                         {
193                                                                 return false;
194                                                         }
195                                                 }
196                                         }
197                                 }
198                         }
199                 }
200         }
201
202         return true;
203 }
204
205 static bool checkComputeSubgroupInvocationID(std::vector<const void*> datas,
206                 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
207                 deUint32 subgroupSize)
208 {
209         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
210
211         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
212         {
213                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
214                 {
215                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
216                         {
217                                 const deUint32 totalLocalSize =
218                                         localSize[0] * localSize[1] * localSize[2];
219                                 vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
220
221                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
222                                 {
223                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
224                                         {
225                                                 for (deUint32 lZ = 0; lZ < localSize[2];
226                                                                 ++lZ)
227                                                 {
228                                                         const deUint32 globalInvocationX =
229                                                                 nX * localSize[0] + lX;
230                                                         const deUint32 globalInvocationY =
231                                                                 nY * localSize[1] + lY;
232                                                         const deUint32 globalInvocationZ =
233                                                                 nZ * localSize[2] + lZ;
234
235                                                         const deUint32 globalSizeX =
236                                                                 numWorkgroups[0] * localSize[0];
237                                                         const deUint32 globalSizeY =
238                                                                 numWorkgroups[1] * localSize[1];
239
240                                                         const deUint32 offset =
241                                                                 globalSizeX *
242                                                                 ((globalSizeY *
243                                                                   globalInvocationZ) +
244                                                                  globalInvocationY) +
245                                                                 globalInvocationX;
246
247                                                         deUint32 subgroupInvocationID = data[(offset * 4) + 1];
248
249                                                         if (subgroupInvocationID >= subgroupSize)
250                                                         {
251                                                                 return false;
252                                                         }
253
254                                                         subgroupInvocationHits[subgroupInvocationID]++;
255                                                 }
256                                         }
257                                 }
258
259                                 deUint32 totalInvocationsRun = 0;
260                                 for (deUint32 i = 0; i < subgroupSize; ++i)
261                                 {
262                                         totalInvocationsRun += subgroupInvocationHits[i];
263                                 }
264
265                                 if (totalInvocationsRun != totalLocalSize)
266                                 {
267                                         return false;
268                                 }
269                         }
270                 }
271         }
272
273         return true;
274 }
275
276 static bool checkComputeNumSubgroups(std::vector<const void*> datas,
277                                                                          const deUint32 numWorkgroups[3], const deUint32 localSize[3],
278                                                                          deUint32)
279 {
280         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
281
282         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
283         {
284                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
285                 {
286                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
287                         {
288                                 const deUint32 totalLocalSize =
289                                         localSize[0] * localSize[1] * localSize[2];
290
291                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
292                                 {
293                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
294                                         {
295                                                 for (deUint32 lZ = 0; lZ < localSize[2];
296                                                                 ++lZ)
297                                                 {
298                                                         const deUint32 globalInvocationX =
299                                                                 nX * localSize[0] + lX;
300                                                         const deUint32 globalInvocationY =
301                                                                 nY * localSize[1] + lY;
302                                                         const deUint32 globalInvocationZ =
303                                                                 nZ * localSize[2] + lZ;
304
305                                                         const deUint32 globalSizeX =
306                                                                 numWorkgroups[0] * localSize[0];
307                                                         const deUint32 globalSizeY =
308                                                                 numWorkgroups[1] * localSize[1];
309
310                                                         const deUint32 offset =
311                                                                 globalSizeX *
312                                                                 ((globalSizeY *
313                                                                   globalInvocationZ) +
314                                                                  globalInvocationY) +
315                                                                 globalInvocationX;
316
317                                                         deUint32 numSubgroups = data[(offset * 4) + 2];
318
319                                                         if (numSubgroups > totalLocalSize)
320                                                         {
321                                                                 return false;
322                                                         }
323                                                 }
324                                         }
325                                 }
326                         }
327                 }
328         }
329
330         return true;
331 }
332
333 static bool checkComputeSubgroupID(std::vector<const void*> datas,
334                                                                    const deUint32 numWorkgroups[3], const deUint32 localSize[3],
335                                                                    deUint32)
336 {
337         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
338
339         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
340         {
341                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
342                 {
343                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
344                         {
345                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
346                                 {
347                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
348                                         {
349                                                 for (deUint32 lZ = 0; lZ < localSize[2];
350                                                                 ++lZ)
351                                                 {
352                                                         const deUint32 globalInvocationX =
353                                                                 nX * localSize[0] + lX;
354                                                         const deUint32 globalInvocationY =
355                                                                 nY * localSize[1] + lY;
356                                                         const deUint32 globalInvocationZ =
357                                                                 nZ * localSize[2] + lZ;
358
359                                                         const deUint32 globalSizeX =
360                                                                 numWorkgroups[0] * localSize[0];
361                                                         const deUint32 globalSizeY =
362                                                                 numWorkgroups[1] * localSize[1];
363
364                                                         const deUint32 offset =
365                                                                 globalSizeX *
366                                                                 ((globalSizeY *
367                                                                   globalInvocationZ) +
368                                                                  globalInvocationY) +
369                                                                 globalInvocationX;
370
371                                                         deUint32 numSubgroups = data[(offset * 4) + 2];
372                                                         deUint32 subgroupID = data[(offset * 4) + 3];
373
374                                                         if (subgroupID >= numSubgroups)
375                                                         {
376                                                                 return false;
377                                                         }
378                                                 }
379                                         }
380                                 }
381                         }
382                 }
383         }
384
385         return true;
386 }
387
388 namespace
389 {
390 struct CaseDefinition
391 {
392         std::string varName;
393         VkShaderStageFlags shaderStage;
394         bool noSSBO;
395 };
396 }
397
398 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
399 {
400         std::ostringstream src;
401         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
402         {
403                 src << "#version 450\n"
404                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
405                         << "layout(location = 0) out vec4 out_color;\n"
406                         << "layout(location = 0) in highp vec4 in_position;\n"
407                         << "\n"
408                         << "void main (void)\n"
409                         << "{\n"
410                         << "  out_color = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 1.0f, 1.0f);\n"
411                         << "  gl_Position = in_position;\n"
412                         << "  gl_PointSize = 1.0f;\n"
413                         << "}\n";
414
415                 programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
416
417                 std::ostringstream source;
418                 source  << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
419                                 << "layout(location = 0) in vec4 in_color;\n"
420                                 << "layout(location = 0) out uvec4 out_color;\n"
421                                 << "void main()\n"
422                                 <<"{\n"
423                                 << "    out_color = uvec4(in_color);\n"
424                                 << "}\n";
425                 programCollection.glslSources.add("fragment") << glu::FragmentSource(source.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
426         }
427         else
428         {
429                 DE_FATAL("Unsupported shader stage");
430         }
431 }
432
433 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
434 {
435         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
436         {
437                 std::ostringstream src;
438
439                 src << "#version 450\n"
440                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
441                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
442                         "local_size_z_id = 2) in;\n"
443                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
444                         << "{\n"
445                         << "  uvec4 result[];\n"
446                         << "};\n"
447                         << "\n"
448                         << "void main (void)\n"
449                         << "{\n"
450                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
451                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
452                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
453                         "gl_GlobalInvocationID.x;\n"
454                         << "  result[offset] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, gl_NumSubgroups, gl_SubgroupID);\n"
455                         << "}\n";
456
457                 programCollection.glslSources.add("comp")
458                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
459         }
460         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
461         {
462                 programCollection.glslSources.add("vert")
463                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
464
465                 std::ostringstream frag;
466
467                 frag << "#version 450\n"
468                          << "#extension GL_KHR_shader_subgroup_basic: enable\n"
469                          << "layout(location = 0) out uvec4 data;\n"
470                          << "void main (void)\n"
471                          << "{\n"
472                          << "  data = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
473                          << "}\n";
474
475                 programCollection.glslSources.add("frag")
476                                 << glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
477         }
478         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
479         {
480                 std::ostringstream src;
481
482                 src << "#version 450\n"
483                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
484                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
485                         << "{\n"
486                         << "  uvec4 result[];\n"
487                         << "};\n"
488                         << "\n"
489                         << "void main (void)\n"
490                         << "{\n"
491                         << "  result[gl_VertexIndex] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
492                         << "}\n";
493
494                 programCollection.glslSources.add("vert")
495                                 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
496         }
497         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
498         {
499                 programCollection.glslSources.add("vert")
500                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
501
502                 std::ostringstream src;
503
504                 src << "#version 450\n"
505                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
506                         << "layout(points) in;\n"
507                         << "layout(points, max_vertices = 1) out;\n"
508                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
509                         << "{\n"
510                         << "  uvec4 result[];\n"
511                         << "};\n"
512                         << "\n"
513                         << "void main (void)\n"
514                         << "{\n"
515                         << "  result[gl_PrimitiveIDIn] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
516                         << "}\n";
517
518                 programCollection.glslSources.add("geom")
519                                 << glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
520         }
521         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
522         {
523                 programCollection.glslSources.add("vert")
524                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
525
526                 programCollection.glslSources.add("tese")
527                                 << glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
528
529                 std::ostringstream src;
530
531                 src << "#version 450\n"
532                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
533                         << "layout(vertices=1) out;\n"
534                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
535                         << "{\n"
536                         << "  uvec4 result[];\n"
537                         << "};\n"
538                         << "\n"
539                         << "void main (void)\n"
540                         << "{\n"
541                         << "  result[gl_PrimitiveID] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
542                         << "}\n";
543
544                 programCollection.glslSources.add("tesc")
545                                 << glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
546         }
547         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
548         {
549                 programCollection.glslSources.add("vert")
550                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
551
552                 programCollection.glslSources.add("tesc")
553                                 << glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
554
555                 std::ostringstream src;
556
557                 src << "#version 450\n"
558                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
559                         << "layout(isolines) in;\n"
560                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
561                         << "{\n"
562                         << "  uvec4 result[];\n"
563                         << "};\n"
564                         << "\n"
565                         << "void main (void)\n"
566                         << "{\n"
567                         << "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
568                         << "}\n";
569
570                 programCollection.glslSources.add("tese")
571                                 << glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
572         }
573         else
574         {
575                 DE_FATAL("Unsupported shader stage");
576         }
577 }
578
579 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
580 {
581         if (!subgroups::isSubgroupSupported(context))
582                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
583
584         if (!areSubgroupOperationsSupportedForStage(
585                                 context, caseDef.shaderStage))
586         {
587                 if (areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
588                 {
589                         return tcu::TestStatus::fail(
590                                            "Shader stage " + getShaderStageName(caseDef.shaderStage) +
591                                            " is required to support subgroup operations!");
592                 }
593                 else
594                 {
595                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
596                 }
597         }
598
599         //Tests which don't use the SSBO
600         if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
601         {
602                 if ("gl_SubgroupSize" == caseDef.varName)
603                 {
604                         return makeVertexFrameBufferTest(
605                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
606                 }
607                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
608                 {
609                         return makeVertexFrameBufferTest(
610                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
611                 }
612         }
613
614         if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
615                         (VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
616         {
617                 if (!subgroups::isVertexSSBOSupportedForDevice(context))
618                 {
619                         TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
620                 }
621         }
622
623         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
624         {
625                 if ("gl_SubgroupSize" == caseDef.varName)
626                 {
627                         return makeComputeTest(
628                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupSize);
629                 }
630                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
631                 {
632                         return makeComputeTest(
633                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupInvocationID);
634                 }
635                 else if ("gl_NumSubgroups" == caseDef.varName)
636                 {
637                         return makeComputeTest(
638                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeNumSubgroups);
639                 }
640                 else if ("gl_SubgroupID" == caseDef.varName)
641                 {
642                         return makeComputeTest(
643                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupID);
644                 }
645                 else
646                 {
647                         return tcu::TestStatus::fail(
648                                            caseDef.varName + " failed (unhandled error checking case " +
649                                            caseDef.varName + ")!");
650                 }
651         }
652         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
653         {
654                 if ("gl_SubgroupSize" == caseDef.varName)
655                 {
656                         return makeFragmentTest(
657                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkFragmentSubgroupSize);
658                 }
659                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
660                 {
661                         return makeFragmentTest(
662                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkFragmentSubgroupInvocationID);
663                 }
664                 else
665                 {
666                         return tcu::TestStatus::fail(
667                                            caseDef.varName + " failed (unhandled error checking case " +
668                                            caseDef.varName + ")!");
669                 }
670         }
671         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
672         {
673                 if ("gl_SubgroupSize" == caseDef.varName)
674                 {
675                         return makeVertexTest(
676                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
677                 }
678                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
679                 {
680                         return makeVertexTest(
681                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
682                 }
683                 else
684                 {
685                         return tcu::TestStatus::fail(
686                                            caseDef.varName + " failed (unhandled error checking case " +
687                                            caseDef.varName + ")!");
688                 }
689         }
690         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
691         {
692                 if ("gl_SubgroupSize" == caseDef.varName)
693                 {
694                         return makeGeometryTest(
695                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
696                 }
697                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
698                 {
699                         return makeGeometryTest(
700                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
701                 }
702                 else
703                 {
704                         return tcu::TestStatus::fail(
705                                            caseDef.varName + " failed (unhandled error checking case " +
706                                            caseDef.varName + ")!");
707                 }
708         }
709         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
710         {
711                 if ("gl_SubgroupSize" == caseDef.varName)
712                 {
713                         return makeTessellationControlTest(
714                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
715                 }
716                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
717                 {
718                         return makeTessellationControlTest(
719                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
720                 }
721                 else
722                 {
723                         return tcu::TestStatus::fail(
724                                            caseDef.varName + " failed (unhandled error checking case " +
725                                            caseDef.varName + ")!");
726                 }
727         }
728         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
729         {
730                 if ("gl_SubgroupSize" == caseDef.varName)
731                 {
732                         return makeTessellationEvaluationTest(
733                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
734                 }
735                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
736                 {
737                         return makeTessellationEvaluationTest(
738                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
739                 }
740                 else
741                 {
742                         return tcu::TestStatus::fail(
743                                            caseDef.varName + " failed (unhandled error checking case " +
744                                            caseDef.varName + ")!");
745                 }
746         }
747         else
748         {
749                 TCU_THROW(InternalError, "Unhandled shader stage");
750         }
751 }
752
753 tcu::TestCaseGroup* createSubgroupsBuiltinVarTests(tcu::TestContext& testCtx)
754 {
755         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
756                         testCtx, "builtin_var", "Subgroup builtin variable tests"));
757
758         const char* const all_stages_vars[] =
759         {
760                 "SubgroupSize",
761                 "SubgroupInvocationID"
762         };
763
764         const char* const compute_only_vars[] =
765         {
766                 "NumSubgroups",
767                 "SubgroupID"
768         };
769
770         const VkShaderStageFlags stages[] =
771         {
772                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
773                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
774                 VK_SHADER_STAGE_GEOMETRY_BIT,
775                 VK_SHADER_STAGE_VERTEX_BIT,
776                 VK_SHADER_STAGE_FRAGMENT_BIT,
777                 VK_SHADER_STAGE_COMPUTE_BIT,
778         };
779
780         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
781         {
782                 const VkShaderStageFlags stage = stages[stageIndex];
783
784                 for (int a = 0; a < DE_LENGTH_OF_ARRAY(all_stages_vars); ++a)
785                 {
786                         const std::string var = all_stages_vars[a];
787
788                         CaseDefinition caseDef = {"gl_" + var, stage, false};
789
790                         addFunctionCaseWithPrograms(group.get(),
791                                                                                 de::toLower(var) + "_" +
792                                                                                 getShaderStageName(stage), "",
793                                                                                 initPrograms, test, caseDef);
794
795                         if (VK_SHADER_STAGE_VERTEX_BIT == stage)
796                         {
797                                 caseDef.noSSBO = true;
798                                 addFunctionCaseWithPrograms(group.get(),
799                                                         de::toLower(var) + "_" +
800                                                         getShaderStageName(stage)+"_framebuffer", "",
801                                                         initFrameBufferPrograms, test, caseDef);
802                         }
803                 }
804         }
805
806         for (int a = 0; a < DE_LENGTH_OF_ARRAY(compute_only_vars); ++a)
807         {
808                 const VkShaderStageFlags stage = VK_SHADER_STAGE_COMPUTE_BIT;
809                 const std::string var = compute_only_vars[a];
810
811                 CaseDefinition caseDef = {"gl_" + var, stage, false};
812
813                 addFunctionCaseWithPrograms(group.get(), de::toLower(var) +
814                                                                         "_" + getShaderStageName(stage), "",
815                                                                         initPrograms, test, caseDef);
816         }
817
818         return group.release();
819 }
820
821 } // subgroups
822 } // vkt