1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
5 * Copyright (c) 2017 The Khronos Group Inc.
6 * Copyright (c) 2017 Codeplay Software Ltd.
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
12 * http://www.apache.org/licenses/LICENSE-2.0
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
22 * \brief Subgroups Tests
23 */ /*--------------------------------------------------------------------*/
25 #include "vktSubgroupsBallotBroadcastTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
// Op type whose generated shader body exercises subgroupBroadcastFirst
// (getOpTypeName maps it to "subgroupBroadcastFirst").
41 OPTYPE_BROADCAST_FIRST,
// Result checker handed to the vertex, geometry, tessellation-control,
// tessellation-evaluation and vertex-framebuffer test runners in test().
//
// datas[0] is reinterpreted as an array of deUint32 and `width` entries are
// visited in order. The trailing unnamed deUint32 parameter is not used.
// NOTE(review): each `val` is presumably compared against 0x3 (the value the
// generated shaders produce when every check passes), as checkCompute does —
// confirm.
45 static bool checkVertexPipelineStages(std::vector<const void*> datas,
46 deUint32 width, deUint32)
48 const deUint32* data =
49 reinterpret_cast<const deUint32*>(datas[0]);
// One result word per element, indices 0 .. width-1.
50 for (deUint32 x = 0; x < width; ++x)
52 deUint32 val = data[x];
// Result checker handed to subgroups::makeFragmentTest in test().
//
// datas[0] is reinterpreted as an array of deUint32 holding width * height
// result words; every (x, y) pair is visited and the word is read from index
// x * height + y. The trailing unnamed deUint32 parameter is not used.
// NOTE(review): each `val` is presumably compared against 0x3, as in
// checkCompute — confirm.
63 static bool checkFragment(std::vector<const void*> datas,
64 deUint32 width, deUint32 height, deUint32)
66 const deUint32* data =
67 reinterpret_cast<const deUint32*>(datas[0]);
68 for (deUint32 x = 0; x < width; ++x)
70 for (deUint32 y = 0; y < height; ++y)
// Index with x as the major axis and y as the minor axis.
72 deUint32 val = data[x * height + y];
// Result checker handed to subgroups::makeComputeTest in test().
//
// datas[0] is reinterpreted as an array of deUint32 result words. Every
// invocation of the dispatch is enumerated (all workgroups given by
// numWorkgroups[0..2], and within each all local invocations given by
// localSize[0..2]) and the word belonging to that invocation is compared
// against 0x3.
//
// 0x3 is what the generated shader body leaves in tempResult when all checks
// pass: the broadcast variant starts from 0x3 and clears it to 0 on a
// mismatch, and the broadcast-first variant ORs in 0x1 and 0x2 for its two
// checks.
84 static bool checkCompute(std::vector<const void*> datas,
85 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
88 const deUint32* data =
89 reinterpret_cast<const deUint32*>(datas[0]);
// Outer three loops: workgroup coordinates.
91 for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
93 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
95 for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
// Inner three loops: local invocation coordinates inside the workgroup.
97 for (deUint32 lX = 0; lX < localSize[0]; ++lX)
99 for (deUint32 lY = 0; lY < localSize[1]; ++lY)
101 for (deUint32 lZ = 0; lZ < localSize[2];
104 const deUint32 globalInvocationX =
105 nX * localSize[0] + lX;
106 const deUint32 globalInvocationY =
107 nY * localSize[1] + lY;
108 const deUint32 globalInvocationZ =
109 nZ * localSize[2] + lZ;
// Total number of invocations along X and Y across the whole dispatch.
111 const deUint32 globalSizeX =
112 numWorkgroups[0] * localSize[0];
113 const deUint32 globalSizeY =
114 numWorkgroups[1] * localSize[1];
// NOTE(review): offset is expected to mirror the compute shader's indexing
// in initPrograms (globalSize.x * ((globalSize.y * z) + y) + x) — confirm.
116 const deUint32 offset =
// Anything other than 0x3 means at least one shader-side check failed.
123 if (0x3 != data[offset])
// Returns the name of the GLSL built-in associated with an op type:
//   OPTYPE_BROADCAST       -> "subgroupBroadcast"
//   OPTYPE_BROADCAST_FIRST -> "subgroupBroadcastFirst"
// Any other value is reported through DE_FATAL("Unsupported op type").
// The returned string is lower-cased and used as the leading part of the
// test case name in createSubgroupsBallotBroadcastTests.
138 std::string getOpTypeName(int opType)
143 DE_FATAL("Unsupported op type");
144 case OPTYPE_BROADCAST:
145 return "subgroupBroadcast";
146 case OPTYPE_BROADCAST_FIRST:
147 return "subgroupBroadcastFirst";
// Parameters describing a single generated test case. Instances are built in
// createSubgroupsBallotBroadcastTests and passed by value to the program
// initialisers and to test().
// NOTE(review): callers also read opType, format and noSSBO members and
// aggregate-initialise the struct as {opType, stage, format, noSSBO} —
// confirm the member order matches.
152 struct CaseDefinition
// Shader stage in which the subgroup operation under test is executed.
155 VkShaderStageFlags shaderStage;
// Generates the shaders for the framebuffer variant of the test (the cases
// registered with caseDef.noSSBO == true and a "_framebuffer" name suffix).
//
// In this variant the input values come from a uniform block and the
// per-invocation result is passed from the vertex shader to the fragment
// shader as a float and converted back to uint there; no storage buffer is
// declared by these shaders.
//
// Only VK_SHADER_STAGE_VERTEX_BIT is handled; any other stage is reported
// through DE_FATAL("Unsupported shader stage").
//
// programCollection: receives the "vert" and "fragment" GLSL sources.
// caseDef:           selects the op type, data format and shader stage.
160 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
// bdy accumulates the GLSL statements that compute tempResult.
// NOTE(review): presumably spliced into main() between the subgroupBallot
// line and the out_color assignment — confirm.
162 std::ostringstream bdy;
164 bdy << " uint tempResult = 0;\n";
166 if (OPTYPE_BROADCAST == caseDef.opType)
// Broadcast variant: start from "all passed" (0x3) and clear on any mismatch.
168 bdy << " tempResult = 0x3;\n";
// One block of GLSL is emitted per candidate source id, so that the id passed
// to subgroupBroadcast is a literal constant in the generated shader.
170 for (deUint32 i = 0; i < subgroups::maxSupportedSubgroupSize(); i++)
173 << " const uint id = " << i << ";\n"
174 << " " << subgroups::getFormatNameForGLSL(caseDef.format)
175 << " op = subgroupBroadcast(data1[gl_SubgroupInvocationID], id);\n"
// The broadcast value is only verified when id is inside the subgroup and
// its bit is set in mask.
176 << " if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
178 << " if (op != data1[id])\n"
180 << " tempResult = 0;\n"
// Broadcast-first variant: scan mask for a set bit and record its index in
// firstActive, then check subgroupBroadcastFirst against data1[firstActive].
// NOTE(review): the scan presumably stops at the first set bit — confirm.
188 bdy << " uint firstActive = 0;\n"
189 << " for (uint i = 0; i < gl_SubgroupSize; i++)\n"
191 << " if (subgroupBallotBitExtract(mask, i))\n"
193 << " firstActive = i;\n"
// First check contributes bit 0x1.
197 << " tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x1 : 0;\n"
// NOTE(review): the emitted condition below selects the firstActive
// invocation itself, whereas the emitted GLSL comment says that invocation
// should be made inactive for the second check — confirm the intended
// comparison.
198 << " // make the firstActive invocation inactive now\n"
199 << " if (firstActive == gl_SubgroupInvocationID)\n"
201 << " for (uint i = 0; i < gl_SubgroupSize; i++)\n"
203 << " if (subgroupBallotBitExtract(mask, i))\n"
205 << " firstActive = i;\n"
// Second check contributes bit 0x2.
209 << " tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x2 : 0;\n"
// Invocations that skip the second check set bit 0x2 unconditionally so the
// host-side comparison against 0x3 still holds for them.
213 << " // the firstActive invocation didn't partake in the second result so set it to true\n"
214 << " tempResult |= 0x2;\n"
218 if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
220 std::ostringstream src;
221 std::ostringstream fragmentSrc;
// Vertex shader: reads data1 from a uniform block sized to
// maxSupportedSubgroupSize() elements and forwards tempResult as a float.
223 src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
224 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
225 << "layout(location = 0) in highp vec4 in_position;\n"
226 << "layout(location = 0) out float out_color;\n"
227 << "layout(set = 0, binding = 0) uniform Buffer1\n"
229 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" << subgroups::maxSupportedSubgroupSize() << "];\n"
232 << "void main (void)\n"
234 << " uvec4 mask = subgroupBallot(true);\n"
236 << " out_color = float(tempResult);\n"
237 << " gl_Position = in_position;\n"
238 << " gl_PointSize = 1.0f;\n"
// Built for SPIR-V 1.3.
241 programCollection.glslSources.add("vert")
242 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
// Fragment shader: converts the interpolated float back to the uint result
// that is written to the colour attachment.
244 fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
245 << "layout(location = 0) in float in_color;\n"
246 << "layout(location = 0) out uint out_color;\n"
249 << " out_color = uint(in_color);\n"
251 programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
// Any stage other than vertex is not supported by the framebuffer variant.
255 DE_FATAL("Unsupported shader stage");
// Generates the shaders for the storage-buffer variant of the test.
//
// A common GLSL body (bdy) that computes tempResult is built first; then,
// depending on caseDef.shaderStage, the shader for the stage under test is
// assembled around it together with whatever companion shaders that pipeline
// needs. Handled stages: compute, fragment, vertex, geometry, tessellation
// control and tessellation evaluation. Any other stage is reported through
// DE_FATAL("Unsupported shader stage").
//
// programCollection: receives the generated GLSL sources.
// caseDef:           selects the op type, data format and shader stage.
259 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
// NOTE(review): bdy is presumably spliced into each main() right after the
// subgroupBallot line — confirm.
261 std::ostringstream bdy;
263 bdy << " uint tempResult = 0;\n";
265 if (OPTYPE_BROADCAST == caseDef.opType)
// Broadcast variant: start from "all passed" (0x3) and clear on any mismatch.
267 bdy << " tempResult = 0x3;\n";
// One block of GLSL is emitted per candidate source id, so that the id passed
// to subgroupBroadcast is a literal constant in the generated shader.
269 for (deUint32 i = 0; i < subgroups::maxSupportedSubgroupSize(); i++)
272 << " const uint id = " << i << ";\n"
273 << " " << subgroups::getFormatNameForGLSL(caseDef.format)
274 << " op = subgroupBroadcast(data1[gl_SubgroupInvocationID], id);\n"
// The broadcast value is only verified when id is inside the subgroup and
// its bit is set in mask.
275 << " if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
277 << " if (op != data1[id])\n"
279 << " tempResult = 0;\n"
// Broadcast-first variant: scan mask for a set bit and record its index in
// firstActive, then check subgroupBroadcastFirst against data1[firstActive].
// NOTE(review): the scan presumably stops at the first set bit — confirm.
287 bdy << " uint firstActive = 0;\n"
288 << " for (uint i = 0; i < gl_SubgroupSize; i++)\n"
290 << " if (subgroupBallotBitExtract(mask, i))\n"
292 << " firstActive = i;\n"
// First check contributes bit 0x1.
296 << " tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x1 : 0;\n"
// NOTE(review): the emitted condition below selects the firstActive
// invocation itself, whereas the emitted GLSL comment says that invocation
// should be made inactive for the second check — confirm the intended
// comparison.
297 << " // make the firstActive invocation inactive now\n"
298 << " if (firstActive == gl_SubgroupInvocationID)\n"
300 << " for (uint i = 0; i < gl_SubgroupSize; i++)\n"
302 << " if (subgroupBallotBitExtract(mask, i))\n"
304 << " firstActive = i;\n"
// Second check contributes bit 0x2.
308 << " tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x2 : 0;\n"
// Invocations that skip the second check set bit 0x2 unconditionally so the
// host-side comparison against 0x3 still holds for them.
312 << " // the firstActive invocation didn't partake in the second result so set it to true\n"
313 << " tempResult |= 0x2;\n"
// ---- Compute: one result word per global invocation. ----
317 if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
319 std::ostringstream src;
// The workgroup size is supplied through specialization constants 0, 1 and 2.
321 src << "#version 450\n"
322 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
323 << "layout (local_size_x_id = 0, local_size_y_id = 1, "
324 "local_size_z_id = 2) in;\n"
325 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
327 << " uint result[];\n"
329 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
331 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
334 << "void main (void)\n"
// Linearise the 3D global invocation id; the host-side checkCompute is
// expected to index the result buffer the same way.
336 << " uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
337 << " highp uint offset = globalSize.x * ((globalSize.y * "
338 "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
339 "gl_GlobalInvocationID.x;\n"
340 << " uvec4 mask = subgroupBallot(true);\n"
342 << " result[offset] = tempResult;\n"
345 programCollection.glslSources.add("comp")
346 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
// ---- Fragment: result is written to the uint colour output. ----
348 else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
// Companion vertex shader supplied by the shared subgroup test utilities.
350 programCollection.glslSources.add("vert")
351 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
353 std::ostringstream frag;
// Input data is a read-only storage buffer at binding 0.
355 frag << "#version 450\n"
356 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
357 << "layout(location = 0) out uint result;\n"
358 << "layout(set = 0, binding = 0, std430) readonly buffer Buffer1\n"
360 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
362 << "void main (void)\n"
364 << " uvec4 mask = subgroupBallot(true);\n"
366 << " result = tempResult;\n"
369 programCollection.glslSources.add("frag")
370 << glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
// ---- Vertex: one result word per vertex, indexed by gl_VertexIndex. ----
372 else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
374 std::ostringstream src;
// Binding 0 holds the results, binding 1 the input data.
376 src << "#version 450\n"
377 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
378 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
380 << " uint result[];\n"
382 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
384 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
387 << "void main (void)\n"
389 << " uvec4 mask = subgroupBallot(true);\n"
391 << " result[gl_VertexIndex] = tempResult;\n"
392 << " gl_PointSize = 1.0f;\n"
395 programCollection.glslSources.add("vert")
396 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
// ---- Geometry: one result word per input primitive (gl_PrimitiveIDIn). ----
398 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
// Companion vertex shader supplied by the shared subgroup test utilities.
400 programCollection.glslSources.add("vert")
401 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
403 std::ostringstream src;
// Points in, at most one point out.
405 src << "#version 450\n"
406 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
407 << "layout(points) in;\n"
408 << "layout(points, max_vertices = 1) out;\n"
409 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
411 << " uint result[];\n"
413 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
415 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
418 << "void main (void)\n"
420 << " uvec4 mask = subgroupBallot(true);\n"
422 << " result[gl_PrimitiveIDIn] = tempResult;\n"
425 programCollection.glslSources.add("geom")
426 << glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
// ---- Tessellation control: one result word per patch (gl_PrimitiveID). ----
428 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
// Companion vertex shader supplied by the shared subgroup test utilities.
430 programCollection.glslSources.add("vert")
431 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
// Empty pass-through evaluation shader; it is added without explicit
// ShaderBuildOptions, unlike the other sources in this function.
433 programCollection.glslSources.add("tese")
434 << glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
436 std::ostringstream src;
// One output vertex per patch.
438 src << "#version 450\n"
439 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
440 << "layout(vertices=1) out;\n"
441 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
443 << " uint result[];\n"
445 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
447 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
450 << "void main (void)\n"
452 << " uvec4 mask = subgroupBallot(true);\n"
454 << " result[gl_PrimitiveID] = tempResult;\n"
457 programCollection.glslSources.add("tesc")
458 << glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
// ---- Tessellation evaluation: two result words per patch. ----
460 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
// Companion vertex shader supplied by the shared subgroup test utilities.
462 programCollection.glslSources.add("vert")
463 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
// Minimal control shader that sets all four outer tessellation levels to 1.0;
// it is added without explicit ShaderBuildOptions.
465 programCollection.glslSources.add("tesc")
466 << glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
468 std::ostringstream src;
470 src << "#version 450\n"
471 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
472 << "layout(isolines) in;\n"
473 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
475 << " uint result[];\n"
477 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
479 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
482 << "void main (void)\n"
484 << " uvec4 mask = subgroupBallot(true);\n"
// The result slot is gl_PrimitiveID * 2 plus gl_TessCoord.x rounded to the
// nearest integer (via + 0.5 and truncation).
486 << " result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
489 programCollection.glslSources.add("tese")
490 << glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
// No branch above matched the requested stage.
494 DE_FATAL("Unsupported shader stage");
// Runs one ballot-broadcast test case.
//
// First verifies device support and throws NotSupportedError when subgroups,
// the requested stage, the ballot feature, or doubles (for double formats)
// are unavailable. If the stage lacks subgroup support although it is
// required to have it, a failing TestStatus is returned instead.
//
// Then dispatches to the subgroup test runner matching caseDef.shaderStage,
// always with a single input buffer of caseDef.format holding
// maxSupportedSubgroupSize() elements initialised with InitializeNonZero, and
// VK_FORMAT_R32_UINT as the result format.
//
// Throws InternalError("Unhandled shader stage") if no runner matches.
498 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
500 if (!subgroups::isSubgroupSupported(context))
501 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
503 if (!subgroups::areSubgroupOperationsSupportedForStage(
504 context, caseDef.shaderStage))
// Missing support is a failure for stages that must support subgroup
// operations, and merely "not supported" otherwise.
506 if (subgroups::areSubgroupOperationsRequiredForStage(
507 caseDef.shaderStage))
509 return tcu::TestStatus::fail(
511 subgroups::getShaderStageName(caseDef.shaderStage) +
512 " is required to support subgroup operations!");
516 TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
520 if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BALLOT_BIT))
522 TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
525 if (subgroups::isDoubleFormat(caseDef.format) &&
526 !subgroups::isDoubleSupportedForDevice(context))
528 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
// The framebuffer variant returns here, before the vertex-stage SSBO write
// support check further down.
531 //Tests which don't use the SSBO
532 if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
534 subgroups::SSBOData inputData[1];
535 inputData[0].format = caseDef.format;
536 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
537 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
539 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 1, checkVertexPipelineStages);
// Every stage except fragment and compute needs vertex-stage SSBO write
// support on the device.
542 if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
543 (VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
545 if (!subgroups::isVertexSSBOSupportedForDevice(context))
547 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
// Per-stage dispatch. Each branch sets up the same single input buffer and
// differs only in the runner and result checker used.
551 if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
553 subgroups::SSBOData inputData[1];
554 inputData[0].format = caseDef.format;
555 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
556 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
558 return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT,
559 inputData, 1, checkFragment);
561 else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
563 subgroups::SSBOData inputData[1];
564 inputData[0].format = caseDef.format;
565 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
566 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
568 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT,
569 inputData, 1, checkCompute);
// The remaining stages all share checkVertexPipelineStages as their checker.
571 else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
573 subgroups::SSBOData inputData[1];
574 inputData[0].format = caseDef.format;
575 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
576 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
578 return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT,
579 inputData, 1, checkVertexPipelineStages);
581 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
583 subgroups::SSBOData inputData[1];
584 inputData[0].format = caseDef.format;
585 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
586 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
588 return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT,
589 inputData, 1, checkVertexPipelineStages);
591 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
593 subgroups::SSBOData inputData[1];
594 inputData[0].format = caseDef.format;
595 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
596 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
598 return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT,
599 inputData, 1, checkVertexPipelineStages);
601 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
603 subgroups::SSBOData inputData[1];
604 inputData[0].format = caseDef.format;
605 inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
606 inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
608 return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT,
609 inputData, 1, checkVertexPipelineStages);
// Reached only when caseDef.shaderStage matched none of the stages above.
613 TCU_THROW(InternalError, "Unhandled shader stage");
// Builds the "ballot_broadcast" test group.
//
// One case is registered for every combination of shader stage, data format
// and op type, named "<op>_<glsl format>_<stage>" with the op name
// lower-cased. For the vertex stage an additional "<name>_framebuffer" case
// is registered that uses initFrameBufferPrograms with caseDef.noSSBO set.
//
// Returns the released group pointer; ownership passes to the caller.
622 tcu::TestCaseGroup* createSubgroupsBallotBroadcastTests(tcu::TestContext& testCtx)
// Held in a MovePtr until the function returns it via release().
624 de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
625 testCtx, "ballot_broadcast", "Subgroup ballot broadcast category tests"));
// Shader stages to generate cases for.
627 const VkShaderStageFlags stages[] =
629 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
630 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
631 VK_SHADER_STAGE_GEOMETRY_BIT,
632 VK_SHADER_STAGE_VERTEX_BIT,
633 VK_SHADER_STAGE_FRAGMENT_BIT,
634 VK_SHADER_STAGE_COMPUTE_BIT
// Data formats to generate cases for; each is turned into a GLSL type name
// through subgroups::getFormatNameForGLSL.
637 const VkFormat formats[] =
639 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
640 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
641 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
642 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
643 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
644 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
645 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
646 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
647 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
650 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
652 const VkShaderStageFlags stage = stages[stageIndex];
654 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
656 const VkFormat format = formats[formatIndex];
// Op types are enumerated from 0 up to (but excluding) OPTYPE_LAST.
658 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
// Last initialiser is noSSBO: the storage-buffer variant is registered first.
660 CaseDefinition caseDef = {opTypeIndex, stage, format, false};
662 std::ostringstream name;
664 std::string op = getOpTypeName(opTypeIndex);
666 name << de::toLower(op) << "_" << subgroups::getFormatNameForGLSL(format)
667 << "_" << getShaderStageName(stage);
669 addFunctionCaseWithPrograms(group.get(), name.str(),
670 "", initPrograms, test, caseDef);
// The vertex stage additionally gets a framebuffer variant of the same case.
672 if (VK_SHADER_STAGE_VERTEX_BIT == stage )
674 caseDef.noSSBO = true;
675 addFunctionCaseWithPrograms(group.get(), name.str()+"_framebuffer", "",
676 initFrameBufferPrograms, test, caseDef);
682 return group.release();