1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
5 * Copyright (c) 2017 The Khronos Group Inc.
6 * Copyright (c) 2017 Codeplay Software Ltd.
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
12 * http://www.apache.org/licenses/LICENSE-2.0
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
22 * \brief Subgroups Tests
23 */ /*--------------------------------------------------------------------*/
25 #include "vktSubgroupsArithmeticTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
// Result checker for the framebuffer and graphics-pipeline variants. Forwards
// the per-invocation results to vkt::subgroups::check with 0x3 as the
// reference value; the generated shaders set bits 0x1 and 0x2 of tempResult
// when their two comparisons succeed, so 0x3 presumably means "both passed".
64 static bool checkVertexPipelineStages(std::vector<const void*> datas,
65 deUint32 width, deUint32)
67 return vkt::subgroups::check(datas, width, 0x3);
// Result checker for the compute variant. Forwards the results together with
// the dispatch geometry (workgroup count and local size) to
// vkt::subgroups::checkCompute with 0x3 as the reference value, i.e. the two
// comparison bits the compute shader stores into result[].
70 static bool checkCompute(std::vector<const void*> datas,
71 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
74 return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 0x3);
// Returns the name of the GLSL subgroup builtin that implements opType, e.g.
// OPTYPE_INCLUSIVE_ADD -> "subgroupInclusiveAdd" and
// OPTYPE_EXCLUSIVE_XOR -> "subgroupExclusiveXor". Unsupported values reach
// DE_FATAL. The name is emitted as a call in the generated shaders and, when
// lower-cased, is also used as the prefix of each test-case name.
// NOTE(review): the switch header and the non-scan (reduction) cases are not
// visible in this excerpt.
77 std::string getOpTypeName(int opType)
82 DE_FATAL("Unsupported op type");
98 case OPTYPE_INCLUSIVE_ADD:
99 return "subgroupInclusiveAdd";
100 case OPTYPE_INCLUSIVE_MUL:
101 return "subgroupInclusiveMul";
102 case OPTYPE_INCLUSIVE_MIN:
103 return "subgroupInclusiveMin";
104 case OPTYPE_INCLUSIVE_MAX:
105 return "subgroupInclusiveMax";
106 case OPTYPE_INCLUSIVE_AND:
107 return "subgroupInclusiveAnd";
108 case OPTYPE_INCLUSIVE_OR:
109 return "subgroupInclusiveOr";
110 case OPTYPE_INCLUSIVE_XOR:
111 return "subgroupInclusiveXor";
112 case OPTYPE_EXCLUSIVE_ADD:
113 return "subgroupExclusiveAdd";
114 case OPTYPE_EXCLUSIVE_MUL:
115 return "subgroupExclusiveMul";
116 case OPTYPE_EXCLUSIVE_MIN:
117 return "subgroupExclusiveMin";
118 case OPTYPE_EXCLUSIVE_MAX:
119 return "subgroupExclusiveMax";
120 case OPTYPE_EXCLUSIVE_AND:
121 return "subgroupExclusiveAnd";
122 case OPTYPE_EXCLUSIVE_OR:
123 return "subgroupExclusiveOr";
124 case OPTYPE_EXCLUSIVE_XOR:
125 return "subgroupExclusiveXor";
// Builds the GLSL expression that combines lhs and rhs with the binary
// operator underlying opType. The generated shaders use it to fold data[]
// serially into the reference value "ref".
//  - ADD / MUL: plain '+' / '*'.
//  - MIN / MAX: the builtin min()/max() for integer formats. For float and
//    double formats the result is NaN-aware: if one operand is NaN the other
//    operand is returned. Scalars use isnan() ternaries; vectors use nested
//    component-wise mix() with an isnan() selector.
//  - AND / OR / XOR: bitwise &, |, ^ for integer formats. For the
//    VK_FORMAT_R8*_USCALED formats (the bool types in this file), the logical
//    &&, ||, ^^ operators are used instead.
// Unsupported op types reach DE_FATAL.
129 std::string getOpTypeOperation(int opType, vk::VkFormat format, std::string lhs, std::string rhs)
134 DE_FATAL("Unsupported op type");
137 case OPTYPE_INCLUSIVE_ADD:
138 case OPTYPE_EXCLUSIVE_ADD:
139 return lhs + " + " + rhs;
141 case OPTYPE_INCLUSIVE_MUL:
142 case OPTYPE_EXCLUSIVE_MUL:
143 return lhs + " * " + rhs;
145 case OPTYPE_INCLUSIVE_MIN:
146 case OPTYPE_EXCLUSIVE_MIN:
150 return "min(" + lhs + ", " + rhs + ")";
151 case VK_FORMAT_R32_SFLOAT:
152 case VK_FORMAT_R64_SFLOAT:
153 return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs + ")))";
154 case VK_FORMAT_R32G32_SFLOAT:
155 case VK_FORMAT_R32G32B32_SFLOAT:
156 case VK_FORMAT_R32G32B32A32_SFLOAT:
157 case VK_FORMAT_R64G64_SFLOAT:
158 case VK_FORMAT_R64G64B64_SFLOAT:
159 case VK_FORMAT_R64G64B64A64_SFLOAT:
// mix(a, b, bvec) selects b in the components where the selector is true:
// the inner mix keeps lhs where rhs is NaN, the outer one keeps rhs where lhs is NaN.
160 return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
163 case OPTYPE_INCLUSIVE_MAX:
164 case OPTYPE_EXCLUSIVE_MAX:
168 return "max(" + lhs + ", " + rhs + ")";
169 case VK_FORMAT_R32_SFLOAT:
170 case VK_FORMAT_R64_SFLOAT:
171 return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : max(" + lhs + ", " + rhs + ")))";
172 case VK_FORMAT_R32G32_SFLOAT:
173 case VK_FORMAT_R32G32B32_SFLOAT:
174 case VK_FORMAT_R32G32B32A32_SFLOAT:
175 case VK_FORMAT_R64G64_SFLOAT:
176 case VK_FORMAT_R64G64B64_SFLOAT:
177 case VK_FORMAT_R64G64B64A64_SFLOAT:
178 return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
181 case OPTYPE_INCLUSIVE_AND:
182 case OPTYPE_EXCLUSIVE_AND:
186 return lhs + " & " + rhs;
// GLSL's logical operators (&&, ||, ^^) accept only scalar bools, so bool
// vectors are rebuilt component by component.
187 case VK_FORMAT_R8_USCALED:
188 return lhs + " && " + rhs;
189 case VK_FORMAT_R8G8_USCALED:
190 return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)";
191 case VK_FORMAT_R8G8B8_USCALED:
192 return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z)";
193 case VK_FORMAT_R8G8B8A8_USCALED:
194 return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z, " + lhs + ".w && " + rhs + ".w)";
197 case OPTYPE_INCLUSIVE_OR:
198 case OPTYPE_EXCLUSIVE_OR:
202 return lhs + " | " + rhs;
203 case VK_FORMAT_R8_USCALED:
204 return lhs + " || " + rhs;
205 case VK_FORMAT_R8G8_USCALED:
206 return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)";
207 case VK_FORMAT_R8G8B8_USCALED:
208 return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z)";
209 case VK_FORMAT_R8G8B8A8_USCALED:
210 return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z, " + lhs + ".w || " + rhs + ".w)";
213 case OPTYPE_INCLUSIVE_XOR:
214 case OPTYPE_EXCLUSIVE_XOR:
218 return lhs + " ^ " + rhs;
219 case VK_FORMAT_R8_USCALED:
220 return lhs + " ^^ " + rhs;
221 case VK_FORMAT_R8G8_USCALED:
222 return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)";
223 case VK_FORMAT_R8G8B8_USCALED:
224 return "bvec3(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z)";
225 case VK_FORMAT_R8G8B8A8_USCALED:
226 return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z, " + lhs + ".w ^^ " + rhs + ".w)";
// Returns a GLSL constructor expression for the identity element of opType,
// in the GLSL type that corresponds to format. It seeds the reference
// accumulator "ref" before the serial fold.
//  - ADD, OR, XOR: 0.   MUL: 1.   AND: ~0 (all bits set; converts to true for bools).
//  - MIN: +infinity (bit pattern 0x7f800000) for floats, 0x7fffffff for signed
//    ints, 0xffffffffu for unsigned ints.
//  - MAX: -infinity (bit pattern 0xff800000) for floats, 0x80000000 for signed
//    ints, 0 for unsigned ints.
// The first switch classifies format as signed int, unsigned int, float/double
// or bool; unknown formats reach DE_FATAL. The bodies that set the flags and
// the if-conditions inside the MIN/MAX cases are not visible in this excerpt.
231 std::string getIdentity(int opType, vk::VkFormat format)
233 bool isFloat = false;
235 bool isUnsigned = false;
240 DE_FATAL("Unhandled format!");
242 case VK_FORMAT_R32_SINT:
243 case VK_FORMAT_R32G32_SINT:
244 case VK_FORMAT_R32G32B32_SINT:
245 case VK_FORMAT_R32G32B32A32_SINT:
248 case VK_FORMAT_R32_UINT:
249 case VK_FORMAT_R32G32_UINT:
250 case VK_FORMAT_R32G32B32_UINT:
251 case VK_FORMAT_R32G32B32A32_UINT:
254 case VK_FORMAT_R32_SFLOAT:
255 case VK_FORMAT_R32G32_SFLOAT:
256 case VK_FORMAT_R32G32B32_SFLOAT:
257 case VK_FORMAT_R32G32B32A32_SFLOAT:
258 case VK_FORMAT_R64_SFLOAT:
259 case VK_FORMAT_R64G64_SFLOAT:
260 case VK_FORMAT_R64G64B64_SFLOAT:
261 case VK_FORMAT_R64G64B64A64_SFLOAT:
264 case VK_FORMAT_R8_USCALED:
265 case VK_FORMAT_R8G8_USCALED:
266 case VK_FORMAT_R8G8B8_USCALED:
267 case VK_FORMAT_R8G8B8A8_USCALED:
268 break; // bool types are not anything
274 DE_FATAL("Unsupported op type");
277 case OPTYPE_INCLUSIVE_ADD:
278 case OPTYPE_EXCLUSIVE_ADD:
279 return subgroups::getFormatNameForGLSL(format) + "(0)";
281 case OPTYPE_INCLUSIVE_MUL:
282 case OPTYPE_EXCLUSIVE_MUL:
283 return subgroups::getFormatNameForGLSL(format) + "(1)";
285 case OPTYPE_INCLUSIVE_MIN:
286 case OPTYPE_EXCLUSIVE_MIN:
289 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))";
293 return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)";
297 return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)";
301 DE_FATAL("Unhandled case");
305 case OPTYPE_INCLUSIVE_MAX:
306 case OPTYPE_EXCLUSIVE_MAX:
309 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))";
313 return subgroups::getFormatNameForGLSL(format) + "(0x80000000)";
317 return subgroups::getFormatNameForGLSL(format) + "(0)";
321 DE_FATAL("Unhandled case");
325 case OPTYPE_INCLUSIVE_AND:
326 case OPTYPE_EXCLUSIVE_AND:
327 return subgroups::getFormatNameForGLSL(format) + "(~0)";
329 case OPTYPE_INCLUSIVE_OR:
330 case OPTYPE_EXCLUSIVE_OR:
331 return subgroups::getFormatNameForGLSL(format) + "(0)";
333 case OPTYPE_INCLUSIVE_XOR:
334 case OPTYPE_EXCLUSIVE_XOR:
335 return subgroups::getFormatNameForGLSL(format) + "(0)";
// Builds a GLSL boolean expression comparing the serially computed reference
// (lhs) with the subgroup builtin's result (rhs):
//  - Integer and bool scalars (R32_UINT, R32_SINT, R8_USCALED): "==".
//    Other formats not listed below: all(equal(...)).
//  - float/double, MIN and MAX: exact comparison ("==" or all(equal(...))).
//  - float/double, other ops (ADD/MUL; bitwise ops are never generated for
//    float formats): absolute tolerance of 0.00001 per component. This is
//    presumably because the builtin may combine operands in a different order
//    than the serial loop, and float arithmetic is not associative.
339 std::string getCompare(int opType, vk::VkFormat format, std::string lhs, std::string rhs)
341 std::string formatName = subgroups::getFormatNameForGLSL(format);
345 return "all(equal(" + lhs + ", " + rhs + "))";
346 case VK_FORMAT_R8_USCALED:
347 case VK_FORMAT_R32_UINT:
348 case VK_FORMAT_R32_SINT:
349 return "(" + lhs + " == " + rhs + ")";
350 case VK_FORMAT_R32_SFLOAT:
351 case VK_FORMAT_R64_SFLOAT:
355 return "(abs(" + lhs + " - " + rhs + ") < 0.00001)";
357 case OPTYPE_INCLUSIVE_MIN:
358 case OPTYPE_EXCLUSIVE_MIN:
360 case OPTYPE_INCLUSIVE_MAX:
361 case OPTYPE_EXCLUSIVE_MAX:
362 return "(" + lhs + " == " + rhs + ")";
364 case VK_FORMAT_R32G32_SFLOAT:
365 case VK_FORMAT_R32G32B32_SFLOAT:
366 case VK_FORMAT_R32G32B32A32_SFLOAT:
367 case VK_FORMAT_R64G64_SFLOAT:
368 case VK_FORMAT_R64G64B64_SFLOAT:
369 case VK_FORMAT_R64G64B64A64_SFLOAT:
373 return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))";
375 case OPTYPE_INCLUSIVE_MIN:
376 case OPTYPE_EXCLUSIVE_MIN:
378 case OPTYPE_INCLUSIVE_MAX:
379 case OPTYPE_EXCLUSIVE_MAX:
380 return "all(equal(" + lhs + ", " + rhs + "))";
// Parameters of one generated test case. Test registration brace-initializes
// it as {opType, shaderStage, format}. Only the shaderStage member declaration
// is visible in this excerpt; opType and format are read elsewhere in the file.
385 struct CaseDefinition
// Stage(s) under test: VK_SHADER_STAGE_COMPUTE_BIT, VK_SHADER_STAGE_ALL_GRAPHICS,
// or a single vertex-pipeline stage for the framebuffer variants.
388 VkShaderStageFlags shaderStage;
// Generates the shaders for the "framebuffer" variants. The stage under test
// (vertex, geometry, tessellation control or tessellation evaluation) reads
// its input from a uniform buffer at set 0, binding 0. It reports its result
// through the out_color output instead of an SSBO.
//
// The shared body (bdy) recomputes the expected value serially:
//  - It starts from getIdentity() and folds data[index] with
//    getOpTypeOperation() over every invocation in [start, end) whose bit is
//    set in the subgroupBallot mask.
//  - The range is the whole subgroup in the default case (presumably the
//    reduce ops), up to and including the current invocation for inclusive
//    scans, and strictly before it for exclusive scans.
//  - Bit 0x1 of tempResult records the comparison (getCompare) with the
//    builtin while all invocations are active.
//  - Bit 0x2 records a second comparison performed only by odd invocations,
//    after a fresh ballot inside that branch. The remaining invocations set
//    0x2 unconditionally; the branch structure is only partly visible in this
//    excerpt.
// Pass-through shaders for the other stages come from the
// subgroups::set*FrameBuffer helpers. Any other stage reaches DE_FATAL.
392 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
394 const vk::ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
395 std::string indexVars;
396 std::ostringstream bdy;
398 subgroups::setFragmentShaderFrameBuffer(programCollection);
400 if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
401 subgroups::setVertexShaderFrameBuffer(programCollection);
403 switch (caseDef.opType)
406 indexVars = " uint start = 0, end = gl_SubgroupSize;\n";
408 case OPTYPE_INCLUSIVE_ADD:
409 case OPTYPE_INCLUSIVE_MUL:
410 case OPTYPE_INCLUSIVE_MIN:
411 case OPTYPE_INCLUSIVE_MAX:
412 case OPTYPE_INCLUSIVE_AND:
413 case OPTYPE_INCLUSIVE_OR:
414 case OPTYPE_INCLUSIVE_XOR:
415 indexVars = " uint start = 0, end = gl_SubgroupInvocationID + 1;\n";
417 case OPTYPE_EXCLUSIVE_ADD:
418 case OPTYPE_EXCLUSIVE_MUL:
419 case OPTYPE_EXCLUSIVE_MIN:
420 case OPTYPE_EXCLUSIVE_MAX:
421 case OPTYPE_EXCLUSIVE_AND:
422 case OPTYPE_EXCLUSIVE_OR:
423 case OPTYPE_EXCLUSIVE_XOR:
424 indexVars = " uint start = 0, end = gl_SubgroupInvocationID;\n";
429 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = "
430 << getIdentity(caseDef.opType, caseDef.format) << ";\n"
431 << " uint tempResult = 0;\n"
432 << " for (uint index = start; index < end; index++)\n"
434 << " if (subgroupBallotBitExtract(mask, index))\n"
436 << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
439 << " tempResult = " << getCompare(caseDef.opType, caseDef.format, "ref",
440 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x1 : 0;\n"
441 << " if (1 == (gl_SubgroupInvocationID % 2))\n"
443 << " mask = subgroupBallot(true);\n"
444 << " ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n"
445 << " for (uint index = start; index < end; index++)\n"
447 << " if (subgroupBallotBitExtract(mask, index))\n"
449 << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
452 << " tempResult |= " << getCompare(caseDef.opType, caseDef.format, "ref",
453 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x2 : 0;\n"
457 << " tempResult |= 0x2;\n"
// Emit the shader for the stage under test; it writes tempResult to out_color.
460 if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
462 std::ostringstream vertexSrc;
463 vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
464 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
465 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
466 << "layout(location = 0) in highp vec4 in_position;\n"
467 << "layout(location = 0) out float out_color;\n"
468 << "layout(set = 0, binding = 0) uniform Buffer1\n"
470 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
473 << "void main (void)\n"
475 << " uvec4 mask = subgroupBallot(true);\n"
477 << " out_color = float(tempResult);\n"
478 << " gl_Position = in_position;\n"
479 << " gl_PointSize = 1.0f;\n"
481 programCollection.glslSources.add("vert")
482 << glu::VertexSource(vertexSrc.str()) << buildOptions;
484 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
486 std::ostringstream geometry;
488 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
489 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
490 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
491 << "layout(points) in;\n"
492 << "layout(points, max_vertices = 1) out;\n"
493 << "layout(location = 0) out float out_color;\n"
494 << "layout(set = 0, binding = 0) uniform Buffer\n"
496 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
499 << "void main (void)\n"
501 << " uvec4 mask = subgroupBallot(true);\n"
503 << " out_color = float(tempResult);\n"
504 << " gl_Position = gl_in[0].gl_Position;\n"
505 << " EmitVertex();\n"
506 << " EndPrimitive();\n"
509 programCollection.glslSources.add("geometry")
510 << glu::GeometrySource(geometry.str()) << buildOptions;
512 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
514 std::ostringstream controlSource;
515 controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
516 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
517 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
518 << "layout(vertices = 2) out;\n"
519 << "layout(location = 0) out float out_color[];\n"
520 << "layout(set = 0, binding = 0) uniform Buffer1\n"
522 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
525 << "void main (void)\n"
527 << " if (gl_InvocationID == 0)\n"
529 << " gl_TessLevelOuter[0] = 1.0f;\n"
530 << " gl_TessLevelOuter[1] = 1.0f;\n"
532 << " uvec4 mask = subgroupBallot(true);\n"
// NOTE(review): the next literal has no trailing "\n". This is harmless only because
// the following literal begins with spaces, but it is inconsistent with the other lines.
534 << " out_color[gl_InvocationID] = float(tempResult);"
535 << " gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
539 programCollection.glslSources.add("tesc")
540 << glu::TessellationControlSource(controlSource.str()) << buildOptions;
541 subgroups::setTesEvalShaderFrameBuffer(programCollection);
543 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
546 std::ostringstream evaluationSource;
547 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
548 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
549 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
550 << "layout(isolines, equal_spacing, ccw ) in;\n"
551 << "layout(location = 0) out float out_color;\n"
552 << "layout(set = 0, binding = 0) uniform Buffer1\n"
554 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
557 << "void main (void)\n"
559 << " uvec4 mask = subgroupBallot(true);\n"
561 << " out_color = float(tempResult);\n"
562 << " gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
565 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
566 programCollection.glslSources.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
570 DE_FATAL("Unsupported shader stage");
// Generates the shaders for the SSBO-based ("compute" and "graphics") variants.
//
// The shared check body works as follows:
//  - It folds data[index] (starting from getIdentity()) over the invocations
//    in [start, end) that are set in the subgroupBallot mask.
//  - It compares the result with the subgroup builtin via getCompare. Bit 0x1
//    holds that comparison with all invocations active.
//  - Bit 0x2 holds a second comparison that only odd invocations perform,
//    after a fresh ballot.
//
// Compute stage: a single compute shader writes tempResult to result[] at the
// invocation's linearized global index. The input data is at binding 1.
//
// Graphics: vertex, tessellation control, tessellation evaluation, geometry
// and fragment shaders are emitted.
//  - Geometry shaders come from a template with a ${TOPOLOGY} placeholder,
//    expanded by subgroups::addGeometryShadersFromTemplate.
//  - Each stage reads the readonly data buffer at binding 4.
//  - Vertex, tesc, tese and geometry write to their own result buffer at
//    bindings 0, 1, 2 and 3 respectively; the fragment shader writes to a
//    uint color output.
//  - subgroups::addNoSubgroupShader's programs are also added.
574 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
576 std::string indexVars;
577 switch (caseDef.opType)
580 indexVars = " uint start = 0, end = gl_SubgroupSize;\n";
582 case OPTYPE_INCLUSIVE_ADD:
583 case OPTYPE_INCLUSIVE_MUL:
584 case OPTYPE_INCLUSIVE_MIN:
585 case OPTYPE_INCLUSIVE_MAX:
586 case OPTYPE_INCLUSIVE_AND:
587 case OPTYPE_INCLUSIVE_OR:
588 case OPTYPE_INCLUSIVE_XOR:
589 indexVars = " uint start = 0, end = gl_SubgroupInvocationID + 1;\n";
591 case OPTYPE_EXCLUSIVE_ADD:
592 case OPTYPE_EXCLUSIVE_MUL:
593 case OPTYPE_EXCLUSIVE_MIN:
594 case OPTYPE_EXCLUSIVE_MAX:
595 case OPTYPE_EXCLUSIVE_AND:
596 case OPTYPE_EXCLUSIVE_OR:
597 case OPTYPE_EXCLUSIVE_XOR:
598 indexVars = " uint start = 0, end = gl_SubgroupInvocationID;\n";
604 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " ref = "
605 + getIdentity(caseDef.opType, caseDef.format) + ";\n"
606 " uint tempResult = 0;\n"
607 " for (uint index = start; index < end; index++)\n"
609 " if (subgroupBallotBitExtract(mask, index))\n"
611 " ref = " + getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") + ";\n"
614 " tempResult = " + getCompare(caseDef.opType, caseDef.format, "ref", getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") + " ? 0x1 : 0;\n"
615 " if (1 == (gl_SubgroupInvocationID % 2))\n"
617 " mask = subgroupBallot(true);\n"
618 " ref = " + getIdentity(caseDef.opType, caseDef.format) + ";\n"
619 " for (uint index = start; index < end; index++)\n"
621 " if (subgroupBallotBitExtract(mask, index))\n"
623 " ref = " + getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") + ";\n"
626 " tempResult |= " + getCompare(caseDef.opType, caseDef.format, "ref", getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") + " ? 0x2 : 0;\n"
630 " tempResult |= 0x2;\n"
// Compute variant: one uint result per invocation, indexed by its linearized global ID.
633 if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
635 std::ostringstream src;
637 src << "#version 450\n"
638 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
639 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
640 << "layout (local_size_x_id = 0, local_size_y_id = 1, "
641 "local_size_z_id = 2) in;\n"
642 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
644 << " uint result[];\n"
646 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
648 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
651 << "void main (void)\n"
653 << " uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
654 << " highp uint offset = globalSize.x * ((globalSize.y * "
655 "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
656 "gl_GlobalInvocationID.x;\n"
657 << " uvec4 mask = subgroupBallot(true);\n"
659 << " result[offset] = tempResult;\n"
662 programCollection.glslSources.add("comp")
663 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
// Graphics variant: every stage reads data[] from binding 4 and writes to its own result buffer.
668 const std::string vertex =
670 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
671 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
672 "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
676 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
678 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
683 " uvec4 mask = subgroupBallot(true);\n"
685 " result[gl_VertexIndex] = tempResult;\n"
686 " float pixelSize = 2.0f/1024.0f;\n"
687 " float pixelPosition = pixelSize/2.0f - 1.0f;\n"
688 " gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
689 " gl_PointSize = 1.0f;\n"
691 programCollection.glslSources.add("vert")
692 << glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
696 const std::string tesc =
698 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
699 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
700 "layout(vertices=1) out;\n"
701 "layout(set = 0, binding = 1, std430) buffer Buffer1\n"
705 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
707 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
712 " uvec4 mask = subgroupBallot(true);\n"
714 " result[gl_PrimitiveID] = tempResult;\n"
715 " if (gl_InvocationID == 0)\n"
717 " gl_TessLevelOuter[0] = 1.0f;\n"
718 " gl_TessLevelOuter[1] = 1.0f;\n"
720 " gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
722 programCollection.glslSources.add("tesc")
723 << glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
727 const std::string tese =
729 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
730 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
731 "layout(isolines) in;\n"
732 "layout(set = 0, binding = 2, std430) buffer Buffer1\n"
736 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
738 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
743 " uvec4 mask = subgroupBallot(true);\n"
745 " result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
746 " float pixelSize = 2.0f/1024.0f;\n"
747 " gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
749 programCollection.glslSources.add("tese")
750 << glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
754 const std::string geometry =
756 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
757 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
758 "layout(${TOPOLOGY}) in;\n"
759 "layout(points, max_vertices = 1) out;\n"
760 "layout(set = 0, binding = 3, std430) buffer Buffer1\n"
764 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
766 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
771 " uvec4 mask = subgroupBallot(true);\n"
773 " result[gl_PrimitiveIDIn] = tempResult;\n"
774 " gl_Position = gl_in[0].gl_Position;\n"
778 subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
779 programCollection.glslSources);
783 const std::string fragment =
785 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
786 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
787 "layout(location = 0) out uint result;\n"
788 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
790 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
794 " uvec4 mask = subgroupBallot(true);\n"
796 " result = tempResult;\n"
798 programCollection.glslSources.add("fragment")
799 << glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
801 subgroups::addNoSubgroupShader(programCollection);
// Support check run before each case. Throws NotSupportedError if any of the
// following is missing:
//  - subgroup operations;
//  - VK_SUBGROUP_FEATURE_ARITHMETIC_BIT;
//  - for double formats, subgroup operations on doubles.
805 void supportedCheck (Context& context, CaseDefinition caseDef)
807 if (!subgroups::isSubgroupSupported(context))
808 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
810 if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_ARITHMETIC_BIT))
812 TCU_THROW(NotSupportedError, "Device does not support subgroup arithmetic operations");
815 if (subgroups::isDoubleFormat(caseDef.format) &&
816 !subgroups::isDoubleSupportedForDevice(context))
818 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
// Runs a framebuffer variant (no SSBO writes from the stage under test).
// If the stage lacks subgroup support, the test fails when such support is
// required for that stage; otherwise it is reported as NotSupported (the
// branch structure is only partly visible in this excerpt).
// The input buffer holds maxSupportedSubgroupSize() elements of caseDef.format,
// initialized with InitializeNonZero. It is handed to the framebuffer helper
// matching the stage; tessellation control and evaluation share one helper,
// which takes the stage as a parameter. Results are validated by
// checkVertexPipelineStages. Any other stage throws InternalError.
822 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
824 if (!subgroups::areSubgroupOperationsSupportedForStage(
825 context, caseDef.shaderStage))
827 if (subgroups::areSubgroupOperationsRequiredForStage(
828 caseDef.shaderStage))
830 return tcu::TestStatus::fail(
832 subgroups::getShaderStageName(caseDef.shaderStage) +
833 " is required to support subgroup operations!");
837 TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
841 subgroups::SSBOData inputData;
842 inputData.format = caseDef.format;
843 inputData.numElements = subgroups::maxSupportedSubgroupSize();
844 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
846 if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
847 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
848 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
849 return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
850 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
851 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
852 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
853 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
855 TCU_THROW(InternalError, "Unhandled shader stage");
// Checks whether caseDef.shaderStage supports subgroup operations. When it
// does not, and support is not required for that stage, this throws
// NotSupportedError.
// NOTE(review): the return statements are not visible in this excerpt. The
// compute path of test() treats a false result as "required but
// unsupported" and fails the case.
858 bool checkShaderStages (Context& context, const CaseDefinition& caseDef)
860 if (!subgroups::areSubgroupOperationsSupportedForStage(
861 context, caseDef.shaderStage))
863 if (subgroups::areSubgroupOperationsRequiredForStage(
864 caseDef.shaderStage))
870 TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
// Runs an SSBO-based variant.
//  - Compute: fails if checkShaderStages reports that the stage must support
//    subgroups but does not. Otherwise it dispatches via makeComputeTest and
//    validates with checkCompute.
//  - Graphics:
//      - Intersects caseDef.shaderStage with
//        VkPhysicalDeviceSubgroupProperties::supportedStages.
//      - If vertex-pipeline SSBO writes are unsupported, falls back to the
//        fragment stage alone, or throws NotSupportedError when fragment is
//        not among the supported stages.
//      - Throws NotSupportedError if no stage remains.
//      - Otherwise runs allStages with the input data at binding 4 and
//        validates with checkVertexPipelineStages.
// In both paths the input holds maxSupportedSubgroupSize() elements of
// caseDef.format, initialized with InitializeNonZero.
876 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
878 if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
880 if(!checkShaderStages(context,caseDef))
882 return tcu::TestStatus::fail(
884 subgroups::getShaderStageName(caseDef.shaderStage) +
885 " is required to support subgroup operations!");
887 subgroups::SSBOData inputData;
888 inputData.format = caseDef.format;
889 inputData.numElements = subgroups::maxSupportedSubgroupSize();
890 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
892 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkCompute);
// Query which shader stages the device supports subgroup operations in.
896 VkPhysicalDeviceSubgroupProperties subgroupProperties;
897 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
898 subgroupProperties.pNext = DE_NULL;
900 VkPhysicalDeviceProperties2 properties;
901 properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
902 properties.pNext = &subgroupProperties;
904 context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
// Only the requested stages that the device lists in supportedStages are exercised.
906 VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage & subgroupProperties.supportedStages);
// Without vertex-pipeline SSBO writes, only the fragment stage can report results.
908 if ( VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
910 if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
911 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
913 stages = VK_SHADER_STAGE_FRAGMENT_BIT;
916 if ((VkShaderStageFlagBits)0u == stages)
917 TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
919 subgroups::SSBOData inputData;
920 inputData.format = caseDef.format;
921 inputData.numElements = subgroups::maxSupportedSubgroupSize();
922 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
// Binding 4 matches the readonly data buffer declared by the graphics shaders.
923 inputData.binding = 4u;
924 inputData.stages = stages;
926 return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData,
927 1, checkVertexPipelineStages, stages);
936 tcu::TestCaseGroup* createSubgroupsArithmeticTests(tcu::TestContext& testCtx)
938 de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
939 testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
940 de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
941 testCtx, "compute", "Subgroup arithmetic category tests: compute"));
942 de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
943 testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
945 const VkShaderStageFlags stages[] =
947 VK_SHADER_STAGE_VERTEX_BIT,
948 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
949 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
950 VK_SHADER_STAGE_GEOMETRY_BIT,
953 const VkFormat formats[] =
955 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
956 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
957 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
958 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
959 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
960 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
961 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
962 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
963 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
966 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
968 const VkFormat format = formats[formatIndex];
970 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
973 bool isFloat = false;
979 case VK_FORMAT_R32_SFLOAT:
980 case VK_FORMAT_R32G32_SFLOAT:
981 case VK_FORMAT_R32G32B32_SFLOAT:
982 case VK_FORMAT_R32G32B32A32_SFLOAT:
983 case VK_FORMAT_R64_SFLOAT:
984 case VK_FORMAT_R64G64_SFLOAT:
985 case VK_FORMAT_R64G64B64_SFLOAT:
986 case VK_FORMAT_R64G64B64A64_SFLOAT:
989 case VK_FORMAT_R8_USCALED:
990 case VK_FORMAT_R8G8_USCALED:
991 case VK_FORMAT_R8G8B8_USCALED:
992 case VK_FORMAT_R8G8B8A8_USCALED:
997 bool isBitwiseOp = false;
1004 case OPTYPE_INCLUSIVE_AND:
1005 case OPTYPE_EXCLUSIVE_AND:
1007 case OPTYPE_INCLUSIVE_OR:
1008 case OPTYPE_EXCLUSIVE_OR:
1010 case OPTYPE_INCLUSIVE_XOR:
1011 case OPTYPE_EXCLUSIVE_XOR:
1016 if (isFloat && isBitwiseOp)
1018 // Skip float with bitwise category.
1022 if (isBool && !isBitwiseOp)
1024 // Skip bool when its not the bitwise category.
1027 std::string op = getOpTypeName(opTypeIndex);
1030 const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format};
1031 addFunctionCaseWithPrograms(computeGroup.get(),
1032 de::toLower(op) + "_" +
1033 subgroups::getFormatNameForGLSL(format),
1034 "", supportedCheck, initPrograms, test, caseDef);
1038 const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_ALL_GRAPHICS, format};
1039 addFunctionCaseWithPrograms(graphicGroup.get(),
1040 de::toLower(op) + "_" +
1041 subgroups::getFormatNameForGLSL(format),
1042 "", supportedCheck, initPrograms, test, caseDef);
1045 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
1047 const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
1048 addFunctionCaseWithPrograms(framebufferGroup.get(), de::toLower(op) + "_" + subgroups::getFormatNameForGLSL(format) +
1049 "_" + getShaderStageName(caseDef.shaderStage), "",
1050 supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
1055 de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
1056 testCtx, "arithmetic", "Subgroup arithmetic category tests"));
1058 group->addChild(graphicGroup.release());
1059 group->addChild(computeGroup.release());
1060 group->addChild(framebufferGroup.release());
1062 return group.release();