1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
5 * Copyright (c) 2017 The Khronos Group Inc.
6 * Copyright (c) 2017 Codeplay Software Ltd.
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
12 * http://www.apache.org/licenses/LICENSE-2.0
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
22 * \brief Subgroups Tests
23 */ /*--------------------------------------------------------------------*/
25 #include "vktSubgroupsQuadTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
// Quad operations covered by these tests. The enumerators are used as indices
// into the per-op swapTable arrays and are mapped to GLSL built-in names by
// getOpTypeName().
40 OPTYPE_QUAD_BROADCAST = 0,
41 OPTYPE_QUAD_SWAP_HORIZONTAL,
42 OPTYPE_QUAD_SWAP_VERTICAL,
43 OPTYPE_QUAD_SWAP_DIAGONAL,
// Result checker handed to the framebuffer and all-stages test helpers in this
// file. Forwards datas and width to vkt::subgroups::check() with a fixed third
// argument of 1 (presumably the value every result element must hold, since the
// shaders write 1 on success - confirm in vktSubgroupsTestsUtils). The trailing
// unnamed deUint32 parameter is ignored.
47 static bool checkVertexPipelineStages(std::vector<const void*> datas,
48 deUint32 width, deUint32)
50 return vkt::subgroups::check(datas, width, 1);
// Result checker handed to subgroups::makeComputeTest(). Forwards datas,
// numWorkgroups and localSize to vkt::subgroups::checkCompute() with a fixed
// last argument of 1 (presumably the expected value of each result element -
// confirm in vktSubgroupsTestsUtils).
53 static bool checkCompute(std::vector<const void*> datas,
54 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
57 return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
// Maps an OPTYPE_* value to the name of the GLSL built-in that implements it.
// The returned string is spliced directly into the generated shader source and
// is also lower-cased to form the test case name.
60 std::string getOpTypeName(int opType)
// Reached for values that match none of the cases below (the label guarding
// this line is not visible in this listing).
65 DE_FATAL("Unsupported op type");
67 case OPTYPE_QUAD_BROADCAST:
68 return "subgroupQuadBroadcast";
69 case OPTYPE_QUAD_SWAP_HORIZONTAL:
70 return "subgroupQuadSwapHorizontal";
71 case OPTYPE_QUAD_SWAP_VERTICAL:
72 return "subgroupQuadSwapVertical";
73 case OPTYPE_QUAD_SWAP_DIAGONAL:
74 return "subgroupQuadSwapDiagonal";
// Stage(s) the case targets. Compared against single VK_SHADER_STAGE_*_BIT
// values when generating shaders and choosing a test runner; the graphics
// group passes VK_SHADER_STAGE_ALL_GRAPHICS here.
81 VkShaderStageFlags shaderStage;
// Builds the GLSL sources for the framebuffer variants of the quad tests.
//
// Exactly one stage - vertex, geometry, tessellation control or tessellation
// evaluation, selected by caseDef.shaderStage - receives the shader that
// exercises the quad operation. The other stages are filled in by the
// subgroups::set*FrameBuffer helpers. Any other stage value reaches
// DE_FATAL("Unsupported shader stage").
//
// All shaders are built for SPIR-V 1.3 and read their input from a uniform
// block at set 0, binding 0.
86 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
88 const vk::ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
89 std::string swapTable[OPTYPE_LAST];
91 subgroups::setFragmentShaderFrameBuffer(programCollection);
// The vertex stage only needs the helper-provided shader when it is not the
// stage under test.
93 if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
94 subgroups::setVertexShaderFrameBuffer(programCollection);
// Per-op GLSL snippet: for each position inside a quad (0-3), the position the
// swap operation is expected to return data from. Broadcast needs no table
// because its source index is caseDef.direction.
96 swapTable[OPTYPE_QUAD_BROADCAST] = "";
97 swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = " const uint swapTable[4] = {1, 0, 3, 2};\n";
98 swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = " const uint swapTable[4] = {2, 3, 0, 1};\n";
99 swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = " const uint swapTable[4] = {3, 2, 1, 0};\n";
// --- Vertex stage under test: writes 1.0/0.0 to the "result" output. ---
101 if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
103 std::ostringstream vertexSrc;
104 vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
105 << "#extension GL_KHR_shader_subgroup_quad: enable\n"
106 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
107 << "layout(location = 0) in highp vec4 in_position;\n"
108 << "layout(location = 0) out float result;\n"
109 << "layout(set = 0, binding = 0) uniform Buffer1\n"
111 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
114 << "void main (void)\n"
116 << " uvec4 mask = subgroupBallot(true);\n"
117 << swapTable[caseDef.opType];
// Broadcast takes the source quad index (caseDef.direction) as an extra
// argument; the swap operations take only the value. In both cases otherID is
// the subgroup invocation whose data the op result is compared against.
119 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
121 vertexSrc << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
122 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
123 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
127 vertexSrc << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
128 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
129 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
// The comparison is only emitted under a check that otherID is set in the
// ballot mask captured at the top of main().
132 vertexSrc << " if (subgroupBallotBitExtract(mask, otherID))\n"
134 << " result = (op == data[otherID]) ? 1.0f : 0.0f;\n"
138 << " result = 1.0f;\n" // Invocation we read from was inactive, so we can't verify results!
140 << " gl_Position = in_position;\n"
141 << " gl_PointSize = 1.0f;\n"
143 programCollection.glslSources.add("vert")
144 << glu::VertexSource(vertexSrc.str()) << buildOptions;
// --- Geometry stage under test: same check, result goes to "out_color". ---
146 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
148 std::ostringstream geometry;
150 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
151 << "#extension GL_KHR_shader_subgroup_quad: enable\n"
152 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
153 << "layout(points) in;\n"
154 << "layout(points, max_vertices = 1) out;\n"
155 << "layout(location = 0) out float out_color;\n"
157 << "layout(set = 0, binding = 0) uniform Buffer1\n"
159 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
162 << "void main (void)\n"
164 << " uvec4 mask = subgroupBallot(true);\n"
165 << swapTable[caseDef.opType];
167 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
169 geometry << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
170 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
171 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
175 geometry << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
176 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
177 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
180 geometry << " if (subgroupBallotBitExtract(mask, otherID))\n"
182 << " out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
186 << " out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
188 << " gl_Position = gl_in[0].gl_Position;\n"
189 << " EmitVertex();\n"
190 << " EndPrimitive();\n"
193 programCollection.glslSources.add("geometry")
194 << glu::GeometrySource(geometry.str()) << buildOptions;
// --- Tessellation control stage under test: two output vertices per patch,
// each invocation writes its own out_color[gl_InvocationID]. The evaluation
// shader comes from setTesEvalShaderFrameBuffer(). ---
196 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
198 std::ostringstream controlSource;
200 controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
201 << "#extension GL_KHR_shader_subgroup_quad: enable\n"
202 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
203 << "layout(vertices = 2) out;\n"
204 << "layout(location = 0) out float out_color[];\n"
205 << "layout(set = 0, binding = 0) uniform Buffer1\n"
207 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
210 << "void main (void)\n"
212 << " if (gl_InvocationID == 0)\n"
214 << " gl_TessLevelOuter[0] = 1.0f;\n"
215 << " gl_TessLevelOuter[1] = 1.0f;\n"
217 << " uvec4 mask = subgroupBallot(true);\n"
218 << swapTable[caseDef.opType];
220 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
222 controlSource << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
223 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
224 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
228 controlSource << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
229 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
230 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
233 controlSource << " if (subgroupBallotBitExtract(mask, otherID))\n"
235 << " out_color[gl_InvocationID] = (op == data[otherID]) ? 1.0 : 0.0;\n"
239 << " out_color[gl_InvocationID] = 1.0; \n"// Invocation we read from was inactive, so we can't verify results!
241 << " gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
244 programCollection.glslSources.add("tesc")
245 << glu::TessellationControlSource(controlSource.str()) << buildOptions;
246 subgroups::setTesEvalShaderFrameBuffer(programCollection);
// --- Tessellation evaluation stage under test: isolines input; the control
// shader comes from setTesCtrlShaderFrameBuffer(). ---
248 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
250 ostringstream evaluationSource;
251 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
252 << "#extension GL_KHR_shader_subgroup_quad: enable\n"
253 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
254 << "layout(isolines, equal_spacing, ccw ) in;\n"
255 << "layout(location = 0) out float out_color;\n"
256 << "layout(set = 0, binding = 0) uniform Buffer1\n"
258 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
261 << "void main (void)\n"
263 << " uvec4 mask = subgroupBallot(true);\n"
264 << swapTable[caseDef.opType];
266 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
268 evaluationSource << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
269 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
270 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
274 evaluationSource << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
275 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
276 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
279 evaluationSource << " if (subgroupBallotBitExtract(mask, otherID))\n"
281 << " out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
285 << " out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
287 << " gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
290 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
291 programCollection.glslSources.add("tese")
292 << glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
// No framebuffer variant exists for any other stage.
296 DE_FATAL("Unsupported shader stage");
// Builds the GLSL sources for the SSBO-based variants of the quad tests.
//
// When caseDef.shaderStage is VK_SHADER_STAGE_COMPUTE_BIT a single "comp"
// shader is generated. The rest of the function (presumably the else branch;
// the keyword is not visible in this listing) generates one shader per
// graphics stage - "vert", "tesc", "tese", the geometry template and
// "fragment" - and finally calls subgroups::addNoSubgroupShader().
//
// In the graphics shaders the input data is a readonly buffer at binding 4,
// matching inputData.binding in test(). The vertex, tessellation control,
// tessellation evaluation and geometry stages write results to bindings
// 0, 1, 2 and 3 respectively; the fragment stage writes to a colour output.
// All shaders are built for SPIR-V 1.3.
300 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
// Per-op GLSL snippet: for each position inside a quad (0-3), the position the
// swap operation is expected to return data from. Broadcast needs no table.
302 std::string swapTable[OPTYPE_LAST];
303 swapTable[OPTYPE_QUAD_BROADCAST] = "";
304 swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = " const uint swapTable[4] = {1, 0, 3, 2};\n";
305 swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = " const uint swapTable[4] = {2, 3, 0, 1};\n";
306 swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = " const uint swapTable[4] = {3, 2, 1, 0};\n";
// --- Compute: results at binding 0, input data at binding 1. ---
308 if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
310 std::ostringstream src;
312 src << "#version 450\n"
313 << "#extension GL_KHR_shader_subgroup_quad: enable\n"
314 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
// Workgroup dimensions come from specialization constants 0, 1 and 2.
315 << "layout (local_size_x_id = 0, local_size_y_id = 1, "
316 "local_size_z_id = 2) in;\n"
317 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
319 << " uint result[];\n"
321 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
323 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
326 << "void main (void)\n"
// offset linearises the 3D global invocation ID so each invocation writes its
// own element of result[].
328 << " uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
329 << " highp uint offset = globalSize.x * ((globalSize.y * "
330 "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
331 "gl_GlobalInvocationID.x;\n"
332 << " uvec4 mask = subgroupBallot(true);\n"
333 << swapTable[caseDef.opType];
// Broadcast takes the source quad index (caseDef.direction) as an extra
// argument; the swap operations take only the value.
336 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
338 src << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
339 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
340 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
344 src << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
345 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
346 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
349 src << " if (subgroupBallotBitExtract(mask, otherID))\n"
351 << " result[offset] = (op == data[otherID]) ? 1 : 0;\n"
355 << " result[offset] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
359 programCollection.glslSources.add("comp")
360 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
// --- Graphics stages. The "op" / "otherID" declarations are built once into
// sourceType; presumably it is spliced into each stage's main() on lines that
// are not visible in this listing. ---
364 std::ostringstream src;
365 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
367 src << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
368 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
369 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
373 src << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
374 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
375 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
377 const string sourceType = src.str();
// Vertex stage: results at binding 0, one element per gl_VertexIndex.
380 const string vertex =
382 "#extension GL_KHR_shader_subgroup_quad: enable\n"
383 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
384 "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
388 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
390 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
395 " uvec4 mask = subgroupBallot(true);\n"
396 + swapTable[caseDef.opType]
398 " if (subgroupBallotBitExtract(mask, otherID))\n"
400 " result[gl_VertexIndex] = (op == data[otherID]) ? 1 : 0;\n"
404 " result[gl_VertexIndex] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
// Each vertex is positioned from its index using a pixel size of 2/1024.
406 " float pixelSize = 2.0f/1024.0f;\n"
407 " float pixelPosition = pixelSize/2.0f - 1.0f;\n"
408 " gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
410 programCollection.glslSources.add("vert")
411 << glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
// Tessellation control stage: results at binding 1, one element per
// gl_PrimitiveID.
417 "#extension GL_KHR_shader_subgroup_quad: enable\n"
418 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
419 "layout(vertices=1) out;\n"
420 "layout(set = 0, binding = 1, std430) buffer Buffer1\n"
424 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
426 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
431 " uvec4 mask = subgroupBallot(true);\n"
432 + swapTable[caseDef.opType]
434 " if (subgroupBallotBitExtract(mask, otherID))\n"
436 " result[gl_PrimitiveID] = (op == data[otherID]) ? 1 : 0;\n"
440 " result[gl_PrimitiveID] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
442 " if (gl_InvocationID == 0)\n"
444 " gl_TessLevelOuter[0] = 1.0f;\n"
445 " gl_TessLevelOuter[1] = 1.0f;\n"
447 " gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
449 programCollection.glslSources.add("tesc")
450 << glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
// Tessellation evaluation stage: results at binding 2. The index is
// gl_PrimitiveID * 2 plus gl_TessCoord.x rounded to 0 or 1, giving two result
// slots per primitive.
456 "#extension GL_KHR_shader_subgroup_quad: enable\n"
457 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
458 "layout(isolines) in;\n"
459 "layout(set = 0, binding = 2, std430) buffer Buffer1\n"
463 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
465 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
470 " uvec4 mask = subgroupBallot(true);\n"
471 + swapTable[caseDef.opType]
473 " if (subgroupBallotBitExtract(mask, otherID))\n"
475 " result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = (op == data[otherID]) ? 1 : 0;\n"
479 " result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
481 " float pixelSize = 2.0f/1024.0f;\n"
482 " gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
484 programCollection.glslSources.add("tese")
485 << glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
// Geometry stage: results at binding 3, one element per gl_PrimitiveIDIn.
// The source is a template; ${TOPOLOGY} is left for
// subgroups::addGeometryShadersFromTemplate() to handle.
489 const string geometry =
491 "#extension GL_KHR_shader_subgroup_quad: enable\n"
492 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
493 "layout(${TOPOLOGY}) in;\n"
494 "layout(points, max_vertices = 1) out;\n"
495 "layout(set = 0, binding = 3, std430) buffer Buffer1\n"
499 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
501 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
506 " uvec4 mask = subgroupBallot(true);\n"
507 + swapTable[caseDef.opType]
509 " if (subgroupBallotBitExtract(mask, otherID))\n"
511 " result[gl_PrimitiveIDIn] = (op == data[otherID]) ? 1 : 0;\n"
515 " result[gl_PrimitiveIDIn] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
517 " gl_Position = gl_in[0].gl_Position;\n"
521 subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
522 programCollection.glslSources);
// Fragment stage: the result is written to colour output location 0 rather
// than to a storage buffer.
526 const string fragment =
528 "#extension GL_KHR_shader_subgroup_quad: enable\n"
529 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
530 "layout(location = 0) out uint result;\n"
531 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
533 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
537 " uvec4 mask = subgroupBallot(true);\n"
538 + swapTable[caseDef.opType]
540 " if (subgroupBallotBitExtract(mask, otherID))\n"
542 " result = (op == data[otherID]) ? 1 : 0;\n"
546 " result = 1; // Invocation we read from was inactive, so we can't verify results!\n"
549 programCollection.glslSources.add("fragment")
550 << glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
552 subgroups::addNoSubgroupShader(programCollection);
// Support check registered with every test case. Throws NotSupportedError when
// the device lacks any of:
//   - subgroup operations in general,
//   - the VK_SUBGROUP_FEATURE_QUAD_BIT subgroup feature,
//   - double support, when caseDef.format is a double format.
556 void supportedCheck (Context& context, CaseDefinition caseDef)
558 if (!subgroups::isSubgroupSupported(context))
559 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
561 if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_QUAD_BIT))
562 TCU_THROW(NotSupportedError, "Device does not support subgroup quad operations");
// Double formats additionally need device-level double support.
565 if (subgroups::isDoubleFormat(caseDef.format) &&
566 !subgroups::isDoubleSupportedForDevice(context))
568 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
// Runs one framebuffer test case (shaders from initFrameBufferPrograms).
//
// If subgroup operations are unsupported for caseDef.shaderStage, the case
// fails when the stage is required to support them and is reported as
// NotSupported otherwise. The input is maxSupportedSubgroupSize() non-zero
// elements of caseDef.format. Results are validated by
// checkVertexPipelineStages. Throws InternalError for a stage with no
// framebuffer runner.
572 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
574 if (!subgroups::areSubgroupOperationsSupportedForStage(
575 context, caseDef.shaderStage))
577 if (subgroups::areSubgroupOperationsRequiredForStage(
578 caseDef.shaderStage))
580 return tcu::TestStatus::fail(
582 subgroups::getShaderStageName(caseDef.shaderStage) +
583 " is required to support subgroup operations!");
587 TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
591 subgroups::SSBOData inputData;
592 inputData.format = caseDef.format;
593 inputData.numElements = subgroups::maxSupportedSubgroupSize();
// NOTE(review): the doubled ';' below is a harmless empty statement.
594 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;;
596 if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
597 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
598 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
599 return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
// Both tessellation stages use the same helper; the last argument names the
// stage under test.
600 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
601 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
602 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
603 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
605 TCU_THROW(InternalError, "Unhandled shader stage");
// Runs one SSBO-based test case (shaders from initPrograms).
//
// Compute: fails outright if subgroup operations are unsupported for the
// stage, otherwise runs subgroups::makeComputeTest with checkCompute.
//
// Graphics: queries VkPhysicalDeviceSubgroupProperties, restricts the
// requested stages to those the device reports in supportedStages, and runs
// subgroups::allStages with checkVertexPipelineStages. Throws
// NotSupportedError when no usable stage remains.
609 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
611 if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
613 if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
615 return tcu::TestStatus::fail(
617 subgroups::getShaderStageName(caseDef.shaderStage) +
618 " is required to support subgroup operations!");
620 subgroups::SSBOData inputData;
621 inputData.format = caseDef.format;
622 inputData.numElements = subgroups::maxSupportedSubgroupSize();
623 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
625 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkCompute);
// Chain the subgroup properties struct onto VkPhysicalDeviceProperties2 so the
// query below fills in supportedStages.
629 VkPhysicalDeviceSubgroupProperties subgroupProperties;
630 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
631 subgroupProperties.pNext = DE_NULL;
633 VkPhysicalDeviceProperties2 properties;
634 properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
635 properties.pNext = &subgroupProperties;
637 context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
// Keep only the requested stages that the device supports subgroup ops in.
639 VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage & subgroupProperties.supportedStages);
// Without vertex-stage SSBO support the test can only run in the fragment
// stage: NotSupported if fragment is not among the stages; the assignment
// below narrows the run to fragment only (presumably the else branch - the
// keyword is not visible in this listing).
641 if (VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
643 if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
644 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
646 stages = VK_SHADER_STAGE_FRAGMENT_BIT;
649 if ((VkShaderStageFlagBits)0u == stages)
650 TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
652 subgroups::SSBOData inputData;
653 inputData.format = caseDef.format;
654 inputData.numElements = subgroups::maxSupportedSubgroupSize();
655 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
// Binding 4 matches the "data" buffer declared by the graphics shaders in
// initPrograms.
656 inputData.binding = 4u;
657 inputData.stages = stages;
659 return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
// Creates the "quad" test group with three child groups: "graphics",
// "compute" and "framebuffer".
//
// Cases are generated for every combination of direction (0-3), format and
// op type. A case name is the lower-cased GLSL op name, then "_<direction>"
// for broadcast, then "_<GLSL format name>". Framebuffer cases additionally
// append "_<stage name>" for each entry in stages[].
//
// Returns a released raw pointer to the new group.
//
// NOTE(review): the three child-group descriptions say "arithmetic" although
// this file builds the quad group - looks like a copy-paste leftover; confirm
// before changing, since it is runtime text.
668 tcu::TestCaseGroup* createSubgroupsQuadTests(tcu::TestContext& testCtx)
670 de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
671 testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
672 de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
673 testCtx, "compute", "Subgroup arithmetic category tests: compute"));
674 de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
675 testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
// Data formats exercised by every op.
677 const VkFormat formats[] =
679 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
680 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
681 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
682 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
683 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
684 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
685 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
686 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
687 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
// Stages that get a framebuffer case each.
690 const VkShaderStageFlags stages[] =
692 VK_SHADER_STAGE_VERTEX_BIT,
693 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
694 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
695 VK_SHADER_STAGE_GEOMETRY_BIT,
// direction is the quad index passed to the broadcast op.
698 for (int direction = 0; direction < 4; ++direction)
700 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
702 const VkFormat format = formats[formatIndex];
704 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
// NOTE(review): op is already lower-cased here, so the second de::toLower on
// the next statement is redundant.
706 const std::string op = de::toLower(getOpTypeName(opTypeIndex));
707 std::ostringstream name;
708 name << de::toLower(op);
// Only broadcast encodes the direction in its case name.
710 if (OPTYPE_QUAD_BROADCAST == opTypeIndex)
712 name << "_" << direction;
718 // We don't need direction for swap operations.
723 name << "_" << subgroups::getFormatNameForGLSL(format);
// Compute case.
726 const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format, direction};
727 addFunctionCaseWithPrograms(computeGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
// All-graphics-stages case.
731 const CaseDefinition caseDef =
734 VK_SHADER_STAGE_ALL_GRAPHICS,
738 addFunctionCaseWithPrograms(graphicGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
// One framebuffer case per stage.
740 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
742 const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format, direction};
743 addFunctionCaseWithPrograms(framebufferGroup.get(), name.str()+"_"+ getShaderStageName(caseDef.shaderStage), "",
744 supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
751 de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
752 testCtx, "quad", "Subgroup quad category tests"));
// release() passes each child pointer to addChild(), and finally the parent
// group pointer to the caller.
754 group->addChild(graphicGroup.release());
755 group->addChild(computeGroup.release());
756 group->addChild(framebufferGroup.release());
758 return group.release();