Merge remote-tracking branch 'goog/upstream-vulkan-cts-next' into vulkan-cts-1.1...
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / subgroups / vktSubgroupsBuiltinVarTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsBuiltinVarTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34
35 namespace vkt
36 {
37 namespace subgroups
38 {
39
40 bool checkVertexPipelineStagesSubgroupSize(std::vector<const void*> datas,
41                 deUint32 width, deUint32 subgroupSize)
42 {
43         const deUint32* data =
44                 reinterpret_cast<const deUint32*>(datas[0]);
45         for (deUint32 x = 0; x < width; ++x)
46         {
47                 deUint32 val = data[x * 4];
48
49                 if (subgroupSize != val)
50                 {
51                         return false;
52                 }
53         }
54
55         return true;
56 }
57
58 bool checkVertexPipelineStagesSubgroupInvocationID(std::vector<const void*> datas,
59                 deUint32 width, deUint32 subgroupSize)
60 {
61         const deUint32* data =
62                 reinterpret_cast<const deUint32*>(datas[0]);
63         vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
64
65         for (deUint32 x = 0; x < width; ++x)
66         {
67                 deUint32 subgroupInvocationID = data[(x * 4) + 1];
68
69                 if (subgroupInvocationID >= subgroupSize)
70                 {
71                         return false;
72                 }
73
74                 subgroupInvocationHits[subgroupInvocationID]++;
75         }
76
77         const deUint32 totalSize = width;
78
79         deUint32 totalInvocationsRun = 0;
80         for (deUint32 i = 0; i < subgroupSize; ++i)
81         {
82                 totalInvocationsRun += subgroupInvocationHits[i];
83         }
84
85         if (totalInvocationsRun != totalSize)
86         {
87                 return false;
88         }
89
90         return true;
91 }
92
93 static bool checkFragmentSubgroupSize(std::vector<const void*> datas,
94                                                                           deUint32 width, deUint32 height, deUint32 subgroupSize)
95 {
96         const deUint32* data =
97                 reinterpret_cast<const deUint32*>(datas[0]);
98         for (deUint32 x = 0; x < width; ++x)
99         {
100                 for (deUint32 y = 0; y < height; ++y)
101                 {
102                         deUint32 val = data[(x * height + y) * 4];
103
104                         if (subgroupSize != val)
105                         {
106                                 return false;
107                         }
108                 }
109         }
110
111         return true;
112 }
113
114 static bool checkFragmentSubgroupInvocationID(
115         std::vector<const void*> datas, deUint32 width, deUint32 height,
116         deUint32 subgroupSize)
117 {
118         const deUint32* data =
119                 reinterpret_cast<const deUint32*>(datas[0]);
120         vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
121
122         for (deUint32 x = 0; x < width; ++x)
123         {
124                 for (deUint32 y = 0; y < height; ++y)
125                 {
126                         deUint32 subgroupInvocationID = data[((x * height + y) * 4) + 1];
127
128                         if (subgroupInvocationID >= subgroupSize)
129                         {
130                                 return false;
131                         }
132
133                         subgroupInvocationHits[subgroupInvocationID]++;
134                 }
135         }
136
137         const deUint32 totalSize = width * height;
138
139         deUint32 totalInvocationsRun = 0;
140         for (deUint32 i = 0; i < subgroupSize; ++i)
141         {
142                 totalInvocationsRun += subgroupInvocationHits[i];
143         }
144
145         if (totalInvocationsRun != totalSize)
146         {
147                 return false;
148         }
149
150         return true;
151 }
152
153 static bool checkComputeSubgroupSize(std::vector<const void*> datas,
154                                                                          const deUint32 numWorkgroups[3], const deUint32 localSize[3],
155                                                                          deUint32 subgroupSize)
156 {
157         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
158
159         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
160         {
161                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
162                 {
163                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
164                         {
165                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
166                                 {
167                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
168                                         {
169                                                 for (deUint32 lZ = 0; lZ < localSize[2];
170                                                                 ++lZ)
171                                                 {
172                                                         const deUint32 globalInvocationX =
173                                                                 nX * localSize[0] + lX;
174                                                         const deUint32 globalInvocationY =
175                                                                 nY * localSize[1] + lY;
176                                                         const deUint32 globalInvocationZ =
177                                                                 nZ * localSize[2] + lZ;
178
179                                                         const deUint32 globalSizeX =
180                                                                 numWorkgroups[0] * localSize[0];
181                                                         const deUint32 globalSizeY =
182                                                                 numWorkgroups[1] * localSize[1];
183
184                                                         const deUint32 offset =
185                                                                 globalSizeX *
186                                                                 ((globalSizeY *
187                                                                   globalInvocationZ) +
188                                                                  globalInvocationY) +
189                                                                 globalInvocationX;
190
191                                                         if (subgroupSize != data[offset * 4])
192                                                         {
193                                                                 return false;
194                                                         }
195                                                 }
196                                         }
197                                 }
198                         }
199                 }
200         }
201
202         return true;
203 }
204
205 static bool checkComputeSubgroupInvocationID(std::vector<const void*> datas,
206                 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
207                 deUint32 subgroupSize)
208 {
209         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
210
211         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
212         {
213                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
214                 {
215                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
216                         {
217                                 const deUint32 totalLocalSize =
218                                         localSize[0] * localSize[1] * localSize[2];
219                                 vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
220
221                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
222                                 {
223                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
224                                         {
225                                                 for (deUint32 lZ = 0; lZ < localSize[2];
226                                                                 ++lZ)
227                                                 {
228                                                         const deUint32 globalInvocationX =
229                                                                 nX * localSize[0] + lX;
230                                                         const deUint32 globalInvocationY =
231                                                                 nY * localSize[1] + lY;
232                                                         const deUint32 globalInvocationZ =
233                                                                 nZ * localSize[2] + lZ;
234
235                                                         const deUint32 globalSizeX =
236                                                                 numWorkgroups[0] * localSize[0];
237                                                         const deUint32 globalSizeY =
238                                                                 numWorkgroups[1] * localSize[1];
239
240                                                         const deUint32 offset =
241                                                                 globalSizeX *
242                                                                 ((globalSizeY *
243                                                                   globalInvocationZ) +
244                                                                  globalInvocationY) +
245                                                                 globalInvocationX;
246
247                                                         deUint32 subgroupInvocationID = data[(offset * 4) + 1];
248
249                                                         if (subgroupInvocationID >= subgroupSize)
250                                                         {
251                                                                 return false;
252                                                         }
253
254                                                         subgroupInvocationHits[subgroupInvocationID]++;
255                                                 }
256                                         }
257                                 }
258
259                                 deUint32 totalInvocationsRun = 0;
260                                 for (deUint32 i = 0; i < subgroupSize; ++i)
261                                 {
262                                         totalInvocationsRun += subgroupInvocationHits[i];
263                                 }
264
265                                 if (totalInvocationsRun != totalLocalSize)
266                                 {
267                                         return false;
268                                 }
269                         }
270                 }
271         }
272
273         return true;
274 }
275
276 static bool checkComputeNumSubgroups(std::vector<const void*> datas,
277                                                                          const deUint32 numWorkgroups[3], const deUint32 localSize[3],
278                                                                          deUint32)
279 {
280         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
281
282         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
283         {
284                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
285                 {
286                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
287                         {
288                                 const deUint32 totalLocalSize =
289                                         localSize[0] * localSize[1] * localSize[2];
290
291                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
292                                 {
293                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
294                                         {
295                                                 for (deUint32 lZ = 0; lZ < localSize[2];
296                                                                 ++lZ)
297                                                 {
298                                                         const deUint32 globalInvocationX =
299                                                                 nX * localSize[0] + lX;
300                                                         const deUint32 globalInvocationY =
301                                                                 nY * localSize[1] + lY;
302                                                         const deUint32 globalInvocationZ =
303                                                                 nZ * localSize[2] + lZ;
304
305                                                         const deUint32 globalSizeX =
306                                                                 numWorkgroups[0] * localSize[0];
307                                                         const deUint32 globalSizeY =
308                                                                 numWorkgroups[1] * localSize[1];
309
310                                                         const deUint32 offset =
311                                                                 globalSizeX *
312                                                                 ((globalSizeY *
313                                                                   globalInvocationZ) +
314                                                                  globalInvocationY) +
315                                                                 globalInvocationX;
316
317                                                         deUint32 numSubgroups = data[(offset * 4) + 2];
318
319                                                         if (numSubgroups > totalLocalSize)
320                                                         {
321                                                                 return false;
322                                                         }
323                                                 }
324                                         }
325                                 }
326                         }
327                 }
328         }
329
330         return true;
331 }
332
333 static bool checkComputeSubgroupID(std::vector<const void*> datas,
334                                                                    const deUint32 numWorkgroups[3], const deUint32 localSize[3],
335                                                                    deUint32)
336 {
337         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
338
339         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
340         {
341                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
342                 {
343                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
344                         {
345                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
346                                 {
347                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
348                                         {
349                                                 for (deUint32 lZ = 0; lZ < localSize[2];
350                                                                 ++lZ)
351                                                 {
352                                                         const deUint32 globalInvocationX =
353                                                                 nX * localSize[0] + lX;
354                                                         const deUint32 globalInvocationY =
355                                                                 nY * localSize[1] + lY;
356                                                         const deUint32 globalInvocationZ =
357                                                                 nZ * localSize[2] + lZ;
358
359                                                         const deUint32 globalSizeX =
360                                                                 numWorkgroups[0] * localSize[0];
361                                                         const deUint32 globalSizeY =
362                                                                 numWorkgroups[1] * localSize[1];
363
364                                                         const deUint32 offset =
365                                                                 globalSizeX *
366                                                                 ((globalSizeY *
367                                                                   globalInvocationZ) +
368                                                                  globalInvocationY) +
369                                                                 globalInvocationX;
370
371                                                         deUint32 numSubgroups = data[(offset * 4) + 2];
372                                                         deUint32 subgroupID = data[(offset * 4) + 3];
373
374                                                         if (subgroupID >= numSubgroups)
375                                                         {
376                                                                 return false;
377                                                         }
378                                                 }
379                                         }
380                                 }
381                         }
382                 }
383         }
384
385         return true;
386 }
387
388 namespace
389 {
390 struct CaseDefinition
391 {
392         std::string varName;
393         VkShaderStageFlags shaderStage;
394         bool noSSBO;
395 };
396 }
397
398 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
399 {
400         std::ostringstream src;
401         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
402         {
403                 src << "#version 450\n"
404                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
405                         << "layout(location = 0) out vec4 out_color;\n"
406                         << "layout(location = 0) in highp vec4 in_position;\n"
407                         << "\n"
408                         << "void main (void)\n"
409                         << "{\n"
410                         << "  out_color = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 1.0f, 1.0f);\n"
411                         << "  gl_Position = in_position;\n"
412                         << "  gl_PointSize = 1.0f;\n"
413                         << "}\n";
414
415                 programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
416
417                 std::ostringstream source;
418                 source  << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
419                                 << "layout(location = 0) in vec4 in_color;\n"
420                                 << "layout(location = 0) out uvec4 out_color;\n"
421                                 << "void main()\n"
422                                 <<"{\n"
423                                 << "    out_color = uvec4(in_color);\n"
424                                 << "}\n";
425                 programCollection.glslSources.add("fragment") << glu::FragmentSource(source.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
426         }
427         else
428         {
429                 DE_FATAL("Unsupported shader stage");
430         }
431 }
432
433 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
434 {
435         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
436         {
437                 std::ostringstream src;
438
439                 src << "#version 450\n"
440                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
441                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
442                         "local_size_z_id = 2) in;\n"
443                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
444                         << "{\n"
445                         << "  uvec4 result[];\n"
446                         << "};\n"
447                         << "\n"
448                         << "void main (void)\n"
449                         << "{\n"
450                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
451                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
452                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
453                         "gl_GlobalInvocationID.x;\n"
454                         << "  result[offset] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, gl_NumSubgroups, gl_SubgroupID);\n"
455                         << "}\n";
456
457                 programCollection.glslSources.add("comp")
458                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
459         }
460         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
461         {
462                 programCollection.glslSources.add("vert")
463                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
464
465                 std::ostringstream frag;
466
467                 frag << "#version 450\n"
468                          << "#extension GL_KHR_shader_subgroup_basic: enable\n"
469                          << "layout(location = 0) out uvec4 data;\n"
470                          << "void main (void)\n"
471                          << "{\n"
472                          << "  data = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
473                          << "}\n";
474
475                 programCollection.glslSources.add("frag")
476                                 << glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
477         }
478         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
479         {
480                 std::ostringstream src;
481
482                 src << "#version 450\n"
483                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
484                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
485                         << "{\n"
486                         << "  uvec4 result[];\n"
487                         << "};\n"
488                         << "\n"
489                         << "void main (void)\n"
490                         << "{\n"
491                         << "  result[gl_VertexIndex] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
492                         << "  gl_PointSize = 1.0f;\n"
493                         << "}\n";
494
495                 programCollection.glslSources.add("vert")
496                                 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
497         }
498         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
499         {
500                 programCollection.glslSources.add("vert")
501                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
502
503                 std::ostringstream src;
504
505                 src << "#version 450\n"
506                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
507                         << "layout(points) in;\n"
508                         << "layout(points, max_vertices = 1) out;\n"
509                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
510                         << "{\n"
511                         << "  uvec4 result[];\n"
512                         << "};\n"
513                         << "\n"
514                         << "void main (void)\n"
515                         << "{\n"
516                         << "  result[gl_PrimitiveIDIn] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
517                         << "}\n";
518
519                 programCollection.glslSources.add("geom")
520                                 << glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
521         }
522         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
523         {
524                 programCollection.glslSources.add("vert")
525                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
526
527                 programCollection.glslSources.add("tese")
528                                 << glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
529
530                 std::ostringstream src;
531
532                 src << "#version 450\n"
533                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
534                         << "layout(vertices=1) out;\n"
535                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
536                         << "{\n"
537                         << "  uvec4 result[];\n"
538                         << "};\n"
539                         << "\n"
540                         << "void main (void)\n"
541                         << "{\n"
542                         << "  result[gl_PrimitiveID] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
543                         << "}\n";
544
545                 programCollection.glslSources.add("tesc")
546                                 << glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
547         }
548         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
549         {
550                 programCollection.glslSources.add("vert")
551                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
552
553                 programCollection.glslSources.add("tesc")
554                                 << glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
555
556                 std::ostringstream src;
557
558                 src << "#version 450\n"
559                         << "#extension GL_KHR_shader_subgroup_basic: enable\n"
560                         << "layout(isolines) in;\n"
561                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
562                         << "{\n"
563                         << "  uvec4 result[];\n"
564                         << "};\n"
565                         << "\n"
566                         << "void main (void)\n"
567                         << "{\n"
568                         << "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
569                         << "}\n";
570
571                 programCollection.glslSources.add("tese")
572                                 << glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
573         }
574         else
575         {
576                 DE_FATAL("Unsupported shader stage");
577         }
578 }
579
580 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
581 {
582         if (!subgroups::isSubgroupSupported(context))
583                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
584
585         if (!areSubgroupOperationsSupportedForStage(
586                                 context, caseDef.shaderStage))
587         {
588                 if (areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
589                 {
590                         return tcu::TestStatus::fail(
591                                            "Shader stage " + getShaderStageName(caseDef.shaderStage) +
592                                            " is required to support subgroup operations!");
593                 }
594                 else
595                 {
596                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
597                 }
598         }
599
600         //Tests which don't use the SSBO
601         if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
602         {
603                 if ("gl_SubgroupSize" == caseDef.varName)
604                 {
605                         return makeVertexFrameBufferTest(
606                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
607                 }
608                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
609                 {
610                         return makeVertexFrameBufferTest(
611                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
612                 }
613         }
614
615         if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
616                         (VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
617         {
618                 if (!subgroups::isVertexSSBOSupportedForDevice(context))
619                 {
620                         TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
621                 }
622         }
623
624         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
625         {
626                 if ("gl_SubgroupSize" == caseDef.varName)
627                 {
628                         return makeComputeTest(
629                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupSize);
630                 }
631                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
632                 {
633                         return makeComputeTest(
634                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupInvocationID);
635                 }
636                 else if ("gl_NumSubgroups" == caseDef.varName)
637                 {
638                         return makeComputeTest(
639                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeNumSubgroups);
640                 }
641                 else if ("gl_SubgroupID" == caseDef.varName)
642                 {
643                         return makeComputeTest(
644                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupID);
645                 }
646                 else
647                 {
648                         return tcu::TestStatus::fail(
649                                            caseDef.varName + " failed (unhandled error checking case " +
650                                            caseDef.varName + ")!");
651                 }
652         }
653         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
654         {
655                 if ("gl_SubgroupSize" == caseDef.varName)
656                 {
657                         return makeFragmentTest(
658                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkFragmentSubgroupSize);
659                 }
660                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
661                 {
662                         return makeFragmentTest(
663                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkFragmentSubgroupInvocationID);
664                 }
665                 else
666                 {
667                         return tcu::TestStatus::fail(
668                                            caseDef.varName + " failed (unhandled error checking case " +
669                                            caseDef.varName + ")!");
670                 }
671         }
672         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
673         {
674                 if ("gl_SubgroupSize" == caseDef.varName)
675                 {
676                         return makeVertexTest(
677                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
678                 }
679                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
680                 {
681                         return makeVertexTest(
682                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
683                 }
684                 else
685                 {
686                         return tcu::TestStatus::fail(
687                                            caseDef.varName + " failed (unhandled error checking case " +
688                                            caseDef.varName + ")!");
689                 }
690         }
691         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
692         {
693                 if ("gl_SubgroupSize" == caseDef.varName)
694                 {
695                         return makeGeometryTest(
696                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
697                 }
698                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
699                 {
700                         return makeGeometryTest(
701                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
702                 }
703                 else
704                 {
705                         return tcu::TestStatus::fail(
706                                            caseDef.varName + " failed (unhandled error checking case " +
707                                            caseDef.varName + ")!");
708                 }
709         }
710         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
711         {
712                 if ("gl_SubgroupSize" == caseDef.varName)
713                 {
714                         return makeTessellationControlTest(
715                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
716                 }
717                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
718                 {
719                         return makeTessellationControlTest(
720                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
721                 }
722                 else
723                 {
724                         return tcu::TestStatus::fail(
725                                            caseDef.varName + " failed (unhandled error checking case " +
726                                            caseDef.varName + ")!");
727                 }
728         }
729         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
730         {
731                 if ("gl_SubgroupSize" == caseDef.varName)
732                 {
733                         return makeTessellationEvaluationTest(
734                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
735                 }
736                 else if ("gl_SubgroupInvocationID" == caseDef.varName)
737                 {
738                         return makeTessellationEvaluationTest(
739                                            context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
740                 }
741                 else
742                 {
743                         return tcu::TestStatus::fail(
744                                            caseDef.varName + " failed (unhandled error checking case " +
745                                            caseDef.varName + ")!");
746                 }
747         }
748         else
749         {
750                 TCU_THROW(InternalError, "Unhandled shader stage");
751         }
752 }
753
754 tcu::TestCaseGroup* createSubgroupsBuiltinVarTests(tcu::TestContext& testCtx)
755 {
756         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
757                         testCtx, "builtin_var", "Subgroup builtin variable tests"));
758
759         const char* const all_stages_vars[] =
760         {
761                 "SubgroupSize",
762                 "SubgroupInvocationID"
763         };
764
765         const char* const compute_only_vars[] =
766         {
767                 "NumSubgroups",
768                 "SubgroupID"
769         };
770
771         const VkShaderStageFlags stages[] =
772         {
773                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
774                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
775                 VK_SHADER_STAGE_GEOMETRY_BIT,
776                 VK_SHADER_STAGE_VERTEX_BIT,
777                 VK_SHADER_STAGE_FRAGMENT_BIT,
778                 VK_SHADER_STAGE_COMPUTE_BIT,
779         };
780
781         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
782         {
783                 const VkShaderStageFlags stage = stages[stageIndex];
784
785                 for (int a = 0; a < DE_LENGTH_OF_ARRAY(all_stages_vars); ++a)
786                 {
787                         const std::string var = all_stages_vars[a];
788
789                         CaseDefinition caseDef = {"gl_" + var, stage, false};
790
791                         addFunctionCaseWithPrograms(group.get(),
792                                                                                 de::toLower(var) + "_" +
793                                                                                 getShaderStageName(stage), "",
794                                                                                 initPrograms, test, caseDef);
795
796                         if (VK_SHADER_STAGE_VERTEX_BIT == stage)
797                         {
798                                 caseDef.noSSBO = true;
799                                 addFunctionCaseWithPrograms(group.get(),
800                                                         de::toLower(var) + "_" +
801                                                         getShaderStageName(stage)+"_framebuffer", "",
802                                                         initFrameBufferPrograms, test, caseDef);
803                         }
804                 }
805         }
806
807         for (int a = 0; a < DE_LENGTH_OF_ARRAY(compute_only_vars); ++a)
808         {
809                 const VkShaderStageFlags stage = VK_SHADER_STAGE_COMPUTE_BIT;
810                 const std::string var = compute_only_vars[a];
811
812                 CaseDefinition caseDef = {"gl_" + var, stage, false};
813
814                 addFunctionCaseWithPrograms(group.get(), de::toLower(var) +
815                                                                         "_" + getShaderStageName(stage), "",
816                                                                         initPrograms, test, caseDef);
817         }
818
819         return group.release();
820 }
821
822 } // subgroups
823 } // vkt