3a0ad0e017593e438331751723b3dc867ffb3e53
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / subgroups / vktSubgroupsBuiltinMaskVarTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsBuiltinMaskVarTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34
35 namespace vkt
36 {
37 namespace subgroups
38 {
39
40 static bool checkVertexPipelineStages(std::vector<const void*> datas,
41                                                                           deUint32 width, deUint32)
42 {
43         const deUint32* data =
44                 reinterpret_cast<const deUint32*>(datas[0]);
45         for (deUint32 x = 0; x < width; ++x)
46         {
47                 deUint32 val = data[x];
48
49                 if (0x1 != val)
50                 {
51                         return false;
52                 }
53         }
54
55         return true;
56 }
57
58 static bool checkFragment(std::vector<const void*> datas,
59                                                   deUint32 width, deUint32 height, deUint32)
60 {
61         const deUint32* data =
62                 reinterpret_cast<const deUint32*>(datas[0]);
63         for (deUint32 x = 0; x < width; ++x)
64         {
65                 for (deUint32 y = 0; y < height; ++y)
66                 {
67                         deUint32 val = data[(x * height + y)];
68
69                         if (0x1 != val)
70                         {
71                                 return false;
72                         }
73                 }
74         }
75
76         return true;
77 }
78
79 static bool checkCompute(std::vector<const void*> datas,
80                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
81                                                  deUint32)
82 {
83         const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
84
85         for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
86         {
87                 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
88                 {
89                         for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
90                         {
91                                 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
92                                 {
93                                         for (deUint32 lY = 0; lY < localSize[1]; ++lY)
94                                         {
95                                                 for (deUint32 lZ = 0; lZ < localSize[2];
96                                                                 ++lZ)
97                                                 {
98                                                         const deUint32 globalInvocationX =
99                                                                 nX * localSize[0] + lX;
100                                                         const deUint32 globalInvocationY =
101                                                                 nY * localSize[1] + lY;
102                                                         const deUint32 globalInvocationZ =
103                                                                 nZ * localSize[2] + lZ;
104
105                                                         const deUint32 globalSizeX =
106                                                                 numWorkgroups[0] * localSize[0];
107                                                         const deUint32 globalSizeY =
108                                                                 numWorkgroups[1] * localSize[1];
109
110                                                         const deUint32 offset =
111                                                                 globalSizeX *
112                                                                 ((globalSizeY *
113                                                                   globalInvocationZ) +
114                                                                  globalInvocationY) +
115                                                                 globalInvocationX;
116
117                                                         if (0x1 != data[offset])
118                                                         {
119                                                                 return false;
120                                                         }
121                                                 }
122                                         }
123                                 }
124                         }
125                 }
126         }
127
128         return true;
129 }
130
131 namespace
132 {
133 struct CaseDefinition
134 {
135         std::string                     varName;
136         VkShaderStageFlags      shaderStage;
137         bool                            noSSBO;
138 };
139 }
140
141 std::string subgroupMask (const CaseDefinition& caseDef)
142 {
143         std::ostringstream bdy;
144
145         bdy << "  uint tempResult = 0x1;\n"
146                 << "  uvec4 mask = subgroupBallot(true);\n"
147                 << "  const uvec4 var = " << caseDef.varName << ";\n"
148                 << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
149                 << "  {\n";
150
151         if ("gl_SubgroupEqMask" == caseDef.varName)
152         {
153                 bdy << "    if ((i == gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
154                         << "    {\n"
155                         << "      tempResult = 0;\n"
156                         << "    }\n";
157         }
158         else if ("gl_SubgroupGeMask" == caseDef.varName)
159         {
160                 bdy << "    if ((i >= gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
161                         << "    {\n"
162                         << "      tempResult = 0;\n"
163                         << "    }\n";
164         }
165         else if ("gl_SubgroupGtMask" == caseDef.varName)
166         {
167                 bdy << "    if ((i > gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
168                         << "    {\n"
169                         << "      tempResult = 0;\n"
170                         << "    }\n";
171         }
172         else if ("gl_SubgroupLeMask" == caseDef.varName)
173         {
174                 bdy << "    if ((i <= gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
175                         << "    {\n"
176                         << "      tempResult = 0;\n"
177                         << "    }\n";
178         }
179         else if ("gl_SubgroupLtMask" == caseDef.varName)
180         {
181                 bdy << "    if ((i < gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
182                         << "    {\n"
183                         << "      tempResult = 0;\n"
184                         << "    }\n";
185         }
186
187         bdy << "  }\n";
188         return bdy.str();
189 }
190
191 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
192 {
193         std::ostringstream      vertexSrc;
194         std::ostringstream      fragmentSrc;
195         std::ostringstream      bdy;
196         bdy << subgroupMask(caseDef);
197
198         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
199         {
200                 vertexSrc       << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
201                                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
202                                         << "layout(location = 0) out float out_color;\n"
203                                         << "layout(location = 0) in highp vec4 in_position;\n"
204                                         << "\n"
205                                         << "void main (void)\n"
206                                         << "{\n"
207                                         << bdy.str()
208                                         << "  out_color = float(tempResult);\n"
209                                         << "  gl_Position = in_position;\n"
210                                         << "  gl_PointSize = 1.0f;\n"
211                                         << "}\n";
212                 programCollection.glslSources.add("vert") << glu::VertexSource(vertexSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
213
214                 fragmentSrc     << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
215                                         << "layout(location = 0) in highp float in_color;\n"
216                                         << "layout(location = 0) out uint out_color;\n"
217                                         << "void main()\n"
218                                         <<"{\n"
219                                         << "    out_color = uint(in_color);\n"
220                                         << "}\n";
221                 programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
222         }
223         else
224         {
225                 DE_FATAL("Unsupported shader stage");
226         }
227 }
228
229 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
230 {
231         std::ostringstream bdy;
232         bdy << subgroupMask(caseDef);
233
234         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
235         {
236                 std::ostringstream src;
237
238                 src << "#version 450\n"
239                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
240                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
241                         "local_size_z_id = 2) in;\n"
242                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
243                         << "{\n"
244                         << "  uint result[];\n"
245                         << "};\n"
246                         << "\n"
247                         << "void main (void)\n"
248                         << "{\n"
249                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
250                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
251                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
252                         "gl_GlobalInvocationID.x;\n"
253                         << bdy.str()
254                         << "  result[offset] = tempResult;\n"
255                         << "}\n";
256
257                 programCollection.glslSources.add("comp")
258                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
259         }
260         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
261         {
262                 programCollection.glslSources.add("vert")
263                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
264
265                 std::ostringstream frag;
266
267                 frag << "#version 450\n"
268                          << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
269                          << "layout(location = 0) out uint result;\n"
270                          << "void main (void)\n"
271                          << "{\n"
272                          << bdy.str()
273                          << "  result = tempResult;\n"
274                          << "}\n";
275
276                 programCollection.glslSources.add("frag")
277                                 << glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
278         }
279         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
280         {
281                 std::ostringstream src;
282
283                 src << "#version 450\n"
284                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
285                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
286                         << "{\n"
287                         << "  uint result[];\n"
288                         << "};\n"
289                         << "\n"
290                         << "void main (void)\n"
291                         << "{\n"
292                         << bdy.str()
293                         << "  result[gl_VertexIndex] = tempResult;\n"
294                         << "}\n";
295
296                 programCollection.glslSources.add("vert")
297                                 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
298         }
299         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
300         {
301                 programCollection.glslSources.add("vert")
302                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
303
304                 std::ostringstream src;
305
306                 src << "#version 450\n"
307                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
308                         << "layout(points) in;\n"
309                         << "layout(points, max_vertices = 1) out;\n"
310                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
311                         << "{\n"
312                         << "  uint result[];\n"
313                         << "};\n"
314                         << "\n"
315                         << "void main (void)\n"
316                         << "{\n"
317                         << bdy.str()
318                         << "  result[gl_PrimitiveIDIn] = tempResult;\n"
319                         << "}\n";
320
321                 programCollection.glslSources.add("geom")
322                                 << glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
323         }
324         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
325         {
326                 programCollection.glslSources.add("vert")
327                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
328
329                 programCollection.glslSources.add("tese")
330                                 << glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
331
332                 std::ostringstream src;
333
334                 src << "#version 450\n"
335                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
336                         << "layout(vertices=1) out;\n"
337                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
338                         << "{\n"
339                         << "  uint result[];\n"
340                         << "};\n"
341                         << "\n"
342                         << "void main (void)\n"
343                         << "{\n"
344                         << bdy.str()
345                         << "  result[gl_PrimitiveID] = tempResult;\n"
346                         << "}\n";
347
348                 programCollection.glslSources.add("tesc")
349                                 << glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
350         }
351         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
352         {
353                 programCollection.glslSources.add("vert")
354                                 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
355
356                 programCollection.glslSources.add("tesc")
357                                 << glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
358
359                 std::ostringstream src;
360
361                 src << "#version 450\n"
362                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
363                         << "layout(isolines) in;\n"
364                         << "layout(set = 0, binding = 0, std430) buffer Output\n"
365                         << "{\n"
366                         << "  uint result[];\n"
367                         << "};\n"
368                         << "\n"
369                         << "void main (void)\n"
370                         << "{\n"
371                         << bdy.str()
372                         << "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
373                         << "}\n";
374
375                 programCollection.glslSources.add("tese")
376                                 << glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
377         }
378         else
379         {
380                 DE_FATAL("Unsupported shader stage");
381         }
382 }
383
384 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
385 {
386         if (!subgroups::isSubgroupSupported(context))
387                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
388
389         if (!areSubgroupOperationsSupportedForStage(
390                                 context, caseDef.shaderStage))
391         {
392                 if (areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
393                 {
394                         return tcu::TestStatus::fail(
395                                            "Shader stage " + getShaderStageName(caseDef.shaderStage) +
396                                            " is required to support subgroup operations!");
397                 }
398                 else
399                 {
400                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
401                 }
402         }
403
404         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BALLOT_BIT))
405         {
406                 TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
407         }
408
409         //Tests which don't use the SSBO
410         if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
411         {
412                 return makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
413         }
414
415         if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
416                         (VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
417         {
418                 if (!subgroups::isVertexSSBOSupportedForDevice(context))
419                 {
420                         TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
421                 }
422         }
423
424         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
425         {
426                 return makeComputeTest(
427                                    context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkCompute);
428         }
429         else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
430         {
431                 return makeFragmentTest(
432                                    context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkFragment);
433         }
434         else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
435         {
436                 return makeVertexTest(
437                                    context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
438         }
439         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
440         {
441                 return makeGeometryTest(
442                                    context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
443         }
444         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
445         {
446                 return makeTessellationControlTest(
447                                    context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
448         }
449         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
450         {
451                 return makeTessellationEvaluationTest(
452                                    context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
453         }
454         else
455         {
456                 TCU_THROW(InternalError, "Unhandled shader stage");
457         }
458 }
459
460 tcu::TestCaseGroup* createSubgroupsBuiltinMaskVarTests(tcu::TestContext& testCtx)
461 {
462         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
463                         testCtx, "builtin_mask_var", "Subgroup builtin mask variable tests"));
464
465         const char* const all_stages_vars[] =
466         {
467                 "SubgroupEqMask",
468                 "SubgroupGeMask",
469                 "SubgroupGtMask",
470                 "SubgroupLeMask",
471                 "SubgroupLtMask",
472         };
473
474         const VkShaderStageFlags stages[] =
475         {
476                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
477                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
478                 VK_SHADER_STAGE_GEOMETRY_BIT,
479                 VK_SHADER_STAGE_VERTEX_BIT,
480                 VK_SHADER_STAGE_FRAGMENT_BIT,
481                 VK_SHADER_STAGE_COMPUTE_BIT,
482         };
483
484         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
485         {
486                 const VkShaderStageFlags stage = stages[stageIndex];
487
488                 for (int a = 0; a < DE_LENGTH_OF_ARRAY(all_stages_vars); ++a)
489                 {
490                         const std::string var = all_stages_vars[a];
491
492                         CaseDefinition caseDef = {"gl_" + var, stage, false};
493
494                         addFunctionCaseWithPrograms(group.get(),
495                                                                                 de::toLower(var) + "_" +
496                                                                                 getShaderStageName(stage), "",
497                                                                                 initPrograms, test, caseDef);
498
499                         if (VK_SHADER_STAGE_VERTEX_BIT == stage)
500                         {
501                                 caseDef.noSSBO = true;
502                                 addFunctionCaseWithPrograms(group.get(),
503                                                         de::toLower(var) + "_" +
504                                                         getShaderStageName(stage)+"_framebuffer", "",
505                                                         initFrameBufferPrograms, test, caseDef);
506                         }
507                 }
508         }
509
510         return group.release();
511 }
512
513 } // subgroups
514 } // vkt