Use a single test source for quad tests
authorGraeme Leese <gleese@broadcom.com>
Fri, 6 Sep 2019 10:39:27 +0000 (11:39 +0100)
committerAlexander Galazin <Alexander.Galazin@arm.com>
Wed, 9 Oct 2019 07:34:24 +0000 (03:34 -0400)
Rather than repeat this through all the various combinations, just keep
one copy of the test source, so that changing it is not such a massive
pain.

Also, use a single source for the extension enables in the tests.

Components: Vulkan
Affects: dEQP-VK.subgroups.quad.*

Change-Id: I26ca2b035483aa47e021de7aefb3c94879c497d0

external/vulkancts/modules/vulkan/subgroups/vktSubgroupsQuadTests.cpp

index 84488e4..c27d883 100755 (executable)
@@ -85,28 +85,66 @@ struct CaseDefinition
        de::SharedPtr<bool>     geometryPointSizeSupported;
 };
 
+std::string GetExtHeader(VkFormat format)
+{
+       return  "#extension GL_KHR_shader_subgroup_quad: enable\n"
+                       "#extension GL_KHR_shader_subgroup_ballot: enable\n" +
+                       subgroups::getAdditionalExtensionForFormat(format);
+}
+
+std::string GetTestSrc(const CaseDefinition &caseDef)
+{
+       const std::string swapTable[OPTYPE_LAST] = {
+               "",
+               "  const uint swapTable[4] = {1, 0, 3, 2};\n",
+               "  const uint swapTable[4] = {2, 3, 0, 1};\n",
+               "  const uint swapTable[4] = {3, 2, 1, 0};\n",
+       };
+
+       std::ostringstream testSrc;
+       testSrc << "  uvec4 mask = subgroupBallot(true);\n"
+                       << swapTable[caseDef.opType];
+       if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
+       {
+               testSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
+                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
+                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
+       }
+       else
+       {
+               testSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
+                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
+                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
+       }
+       testSrc << "  if (subgroupBallotBitExtract(mask, otherID))\n"
+                       << "  {\n"
+                       << "    tempRes = (op == data[otherID]) ? 1 : 0;\n"
+                       << "  }\n"
+                       << "  else\n"
+                       << "  {\n"
+                       << "    tempRes = 1;\n" // Invocation we read from was inactive, so we can't verify results!
+                       << "  }\n";
+
+       return testSrc.str();
+}
+
 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
 {
        const vk::ShaderBuildOptions    buildOptions    (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
-       std::string                     swapTable[OPTYPE_LAST];
 
        subgroups::setFragmentShaderFrameBuffer(programCollection);
 
        if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
                subgroups::setVertexShaderFrameBuffer(programCollection);
 
-       swapTable[OPTYPE_QUAD_BROADCAST] = "";
-       swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = {1, 0, 3, 2};\n";
-       swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = "  const uint swapTable[4] = {2, 3, 0, 1};\n";
-       swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = "  const uint swapTable[4] = {3, 2, 1, 0};\n";
+       std::string extHeader = GetExtHeader(caseDef.format);
+       std::string testSrc = GetTestSrc(caseDef);
 
        if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
        {
                std::ostringstream      vertexSrc;
                vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
-                       << "#extension GL_KHR_shader_subgroup_quad: enable\n"
-                       << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
-                       << subgroups::getAdditionalExtensionForFormat(caseDef.format)
+                       << extHeader.c_str()
                        << "layout(location = 0) in highp vec4 in_position;\n"
                        << "layout(location = 0) out float result;\n"
                        << "layout(set = 0, binding = 0) uniform Buffer1\n"
@@ -116,30 +154,9 @@ void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefiniti
                        << "\n"
                        << "void main (void)\n"
                        << "{\n"
-                       << "  uvec4 mask = subgroupBallot(true);\n"
-                       << swapTable[caseDef.opType];
-
-               if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
-               {
-                       vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
-               }
-               else
-               {
-                       vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
-               }
-
-               vertexSrc << "  if (subgroupBallotBitExtract(mask, otherID))\n"
-                       << "  {\n"
-                       << "    result = (op == data[otherID]) ? 1.0f : 0.0f;\n"
-                       << "  }\n"
-                       << "  else\n"
-                       << "  {\n"
-                       << "    result = 1.0f;\n" // Invocation we read from was inactive, so we can't verify results!
-                       << "  }\n"
+                       << "  uint tempRes;\n"
+                       << testSrc
+                       << "  result = float(tempRes);\n"
                        << "  gl_Position = in_position;\n"
                        << "  gl_PointSize = 1.0f;\n"
                        << "}\n";
@@ -151,9 +168,7 @@ void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefiniti
                std::ostringstream geometry;
 
                geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
-                       << "#extension GL_KHR_shader_subgroup_quad: enable\n"
-                       << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
-                       << subgroups::getAdditionalExtensionForFormat(caseDef.format)
+                       << extHeader.c_str()
                        << "layout(points) in;\n"
                        << "layout(points, max_vertices = 1) out;\n"
                        << "layout(location = 0) out float out_color;\n"
@@ -165,30 +180,9 @@ void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefiniti
                        << "\n"
                        << "void main (void)\n"
                        << "{\n"
-                       << "  uvec4 mask = subgroupBallot(true);\n"
-                       << swapTable[caseDef.opType];
-
-               if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
-               {
-                       geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
-               }
-               else
-               {
-                       geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
-               }
-
-               geometry << "  if (subgroupBallotBitExtract(mask, otherID))\n"
-                       << "  {\n"
-                       << "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
-                       << "  }\n"
-                       << "  else\n"
-                       << "  {\n"
-                       << "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
-                       << "  }\n"
+                       << "  uint tempRes;\n"
+                       << testSrc
+                       << "  out_color = float(tempRes);\n"
                        << "  gl_Position = gl_in[0].gl_Position;\n"
                        << (*caseDef.geometryPointSizeSupported ? "  gl_PointSize = gl_in[0].gl_PointSize;\n" : "")
                        << "  EmitVertex();\n"
@@ -203,9 +197,7 @@ void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefiniti
                std::ostringstream controlSource;
 
                controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
-                       << "#extension GL_KHR_shader_subgroup_quad: enable\n"
-                       << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
-                       << subgroups::getAdditionalExtensionForFormat(caseDef.format)
+                       << extHeader.c_str()
                        << "layout(vertices = 2) out;\n"
                        << "layout(location = 0) out float out_color[];\n"
                        << "layout(set = 0, binding = 0) uniform Buffer1\n"
@@ -220,30 +212,9 @@ void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefiniti
                        << "    gl_TessLevelOuter[0] = 1.0f;\n"
                        << "    gl_TessLevelOuter[1] = 1.0f;\n"
                        << "  }\n"
-                       << "  uvec4 mask = subgroupBallot(true);\n"
-                       << swapTable[caseDef.opType];
-
-               if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
-               {
-                       controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
-               }
-               else
-               {
-                       controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
-               }
-
-               controlSource << "  if (subgroupBallotBitExtract(mask, otherID))\n"
-                       << "  {\n"
-                       << "    out_color[gl_InvocationID] = (op == data[otherID]) ? 1.0 : 0.0;\n"
-                       << "  }\n"
-                       << "  else\n"
-                       << "  {\n"
-                       << "    out_color[gl_InvocationID] = 1.0; \n"// Invocation we read from was inactive, so we can't verify results!
-                       << "  }\n"
+                       << "  uint tempRes;\n"
+                       << testSrc
+                       << "  out_color[gl_InvocationID] = float(tempRes);\n"
                        << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
                        << "}\n";
 
@@ -255,9 +226,7 @@ void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefiniti
        {
                ostringstream evaluationSource;
                evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
-                       << "#extension GL_KHR_shader_subgroup_quad: enable\n"
-                       << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
-                       << subgroups::getAdditionalExtensionForFormat(caseDef.format)
+                       << extHeader.c_str()
                        << "layout(isolines, equal_spacing, ccw ) in;\n"
                        << "layout(location = 0) out float out_color;\n"
                        << "layout(set = 0, binding = 0) uniform Buffer1\n"
@@ -267,30 +236,9 @@ void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefiniti
                        << "\n"
                        << "void main (void)\n"
                        << "{\n"
-                       << "  uvec4 mask = subgroupBallot(true);\n"
-                       << swapTable[caseDef.opType];
-
-               if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
-               {
-                       evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
-               }
-               else
-               {
-                       evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
-               }
-
-               evaluationSource << "  if (subgroupBallotBitExtract(mask, otherID))\n"
-                       << "  {\n"
-                       << "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
-                       << "  }\n"
-                       << "  else\n"
-                       << "  {\n"
-                       << "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
-                       << "  }\n"
+                       << "  uint tempRes;\n"
+                       << testSrc
+                       << "  out_color = float(tempRes);\n"
                        << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
                        << "}\n";
 
@@ -306,20 +254,15 @@ void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefiniti
 
 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
 {
-       std::string swapTable[OPTYPE_LAST];
-       swapTable[OPTYPE_QUAD_BROADCAST] = "";
-       swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = {1, 0, 3, 2};\n";
-       swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = "  const uint swapTable[4] = {2, 3, 0, 1};\n";
-       swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = "  const uint swapTable[4] = {3, 2, 1, 0};\n";
+       std::string extHeader = GetExtHeader(caseDef.format);
+       std::string sourceType = GetTestSrc(caseDef);
 
        if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
        {
                std::ostringstream src;
 
                src << "#version 450\n"
-                       << "#extension GL_KHR_shader_subgroup_quad: enable\n"
-                       << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
-                       << subgroups::getAdditionalExtensionForFormat(caseDef.format)
+                       << extHeader.c_str()
                        << "layout (local_size_x_id = 0, local_size_y_id = 1, "
                        "local_size_z_id = 2) in;\n"
                        << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
@@ -337,31 +280,9 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
                        << "  highp uint offset = globalSize.x * ((globalSize.y * "
                        "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
                        "gl_GlobalInvocationID.x;\n"
-                       << "  uvec4 mask = subgroupBallot(true);\n"
-                       << swapTable[caseDef.opType];
-
-
-               if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
-               {
-                       src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
-               }
-               else
-               {
-                       src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
-               }
-
-               src << "  if (subgroupBallotBitExtract(mask, otherID))\n"
-                       << "  {\n"
-                       << "    result[offset] = (op == data[otherID]) ? 1 : 0;\n"
-                       << "  }\n"
-                       << "  else\n"
-                       << "  {\n"
-                       << "    result[offset] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
-                       << "  }\n"
+                       << "  uint tempRes;\n"
+                       << sourceType
+                       << "  result[offset] = tempRes;\n"
                        << "}\n";
 
                programCollection.glslSources.add("comp")
@@ -369,27 +290,10 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
        }
        else
        {
-               std::ostringstream src;
-               if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
-               {
-                       src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
-               }
-               else
-               {
-                       src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
-                               << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
-                               << "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
-               }
-               const string sourceType = src.str();
-
                {
                        const string vertex =
                                "#version 450\n"
-                               "#extension GL_KHR_shader_subgroup_quad: enable\n"
-                               "#extension GL_KHR_shader_subgroup_ballot: enable\n"
-                               + subgroups::getAdditionalExtensionForFormat(caseDef.format) +
+                               + extHeader +
                                "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
                                "{\n"
                                "  uint result[];\n"
@@ -401,17 +305,9 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
                                "\n"
                                "void main (void)\n"
                                "{\n"
-                               "  uvec4 mask = subgroupBallot(true);\n"
-                               + swapTable[caseDef.opType]
+                               "  uint tempRes;\n"
                                + sourceType +
-                               "  if (subgroupBallotBitExtract(mask, otherID))\n"
-                               "  {\n"
-                               "    result[gl_VertexIndex] = (op == data[otherID]) ? 1 : 0;\n"
-                               "  }\n"
-                               "  else\n"
-                               "  {\n"
-                               "    result[gl_VertexIndex] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
-                               "  }\n"
+                               "  result[gl_VertexIndex] = tempRes;\n"
                                "  float pixelSize = 2.0f/1024.0f;\n"
                                "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
                                "  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
@@ -424,9 +320,7 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
                {
                        const string tesc =
                                "#version 450\n"
-                               "#extension GL_KHR_shader_subgroup_quad: enable\n"
-                               "#extension GL_KHR_shader_subgroup_ballot: enable\n"
-                               + subgroups::getAdditionalExtensionForFormat(caseDef.format) +
+                               + extHeader +
                                "layout(vertices=1) out;\n"
                                "layout(set = 0, binding = 1, std430) buffer Buffer1\n"
                                "{\n"
@@ -439,17 +333,9 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
                                "\n"
                                "void main (void)\n"
                                "{\n"
-                               "  uvec4 mask = subgroupBallot(true);\n"
-                               + swapTable[caseDef.opType]
+                               "  uint tempRes;\n"
                                + sourceType +
-                               "  if (subgroupBallotBitExtract(mask, otherID))\n"
-                               "  {\n"
-                               "    result[gl_PrimitiveID] = (op == data[otherID]) ? 1 : 0;\n"
-                               "  }\n"
-                               "  else\n"
-                               "  {\n"
-                               "    result[gl_PrimitiveID] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
-                               "  }\n"
+                               "  result[gl_PrimitiveID] = tempRes;\n"
                                "  if (gl_InvocationID == 0)\n"
                                "  {\n"
                                "    gl_TessLevelOuter[0] = 1.0f;\n"
@@ -464,9 +350,7 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
                {
                        const string tese =
                                "#version 450\n"
-                               "#extension GL_KHR_shader_subgroup_quad: enable\n"
-                               "#extension GL_KHR_shader_subgroup_ballot: enable\n"
-                               + subgroups::getAdditionalExtensionForFormat(caseDef.format) +
+                               + extHeader +
                                "layout(isolines) in;\n"
                                "layout(set = 0, binding = 2, std430)  buffer Buffer1\n"
                                "{\n"
@@ -479,17 +363,9 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
                                "\n"
                                "void main (void)\n"
                                "{\n"
-                               "  uvec4 mask = subgroupBallot(true);\n"
-                               + swapTable[caseDef.opType]
+                               "  uint tempRes;\n"
                                + sourceType +
-                               "  if (subgroupBallotBitExtract(mask, otherID))\n"
-                               "  {\n"
-                               "    result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = (op == data[otherID]) ? 1 : 0;\n"
-                               "  }\n"
-                               "  else\n"
-                               "  {\n"
-                               "    result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
-                               "  }\n"
+                               "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempRes;\n"
                                "  float pixelSize = 2.0f/1024.0f;\n"
                                "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
                                "}\n";
@@ -500,9 +376,7 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
                {
                        const string geometry =
                                "#version 450\n"
-                               "#extension GL_KHR_shader_subgroup_quad: enable\n"
-                               "#extension GL_KHR_shader_subgroup_ballot: enable\n"
-                               + subgroups::getAdditionalExtensionForFormat(caseDef.format) +
+                               + extHeader +
                                "layout(${TOPOLOGY}) in;\n"
                                "layout(points, max_vertices = 1) out;\n"
                                "layout(set = 0, binding = 3, std430) buffer Buffer1\n"
@@ -516,17 +390,9 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
                                "\n"
                                "void main (void)\n"
                                "{\n"
-                               "  uvec4 mask = subgroupBallot(true);\n"
-                               + swapTable[caseDef.opType]
+                               "  uint tempRes;\n"
                                + sourceType +
-                               "  if (subgroupBallotBitExtract(mask, otherID))\n"
-                               "  {\n"
-                               "    result[gl_PrimitiveIDIn] = (op == data[otherID]) ? 1 : 0;\n"
-                               "  }\n"
-                               "  else\n"
-                               "  {\n"
-                               "    result[gl_PrimitiveIDIn] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
-                               "  }\n"
+                               "  result[gl_PrimitiveIDIn] = tempRes;\n"
                                "  gl_Position = gl_in[0].gl_Position;\n"
                                "  EmitVertex();\n"
                                "  EndPrimitive();\n"
@@ -538,9 +404,7 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
                {
                        const string fragment =
                                "#version 450\n"
-                               "#extension GL_KHR_shader_subgroup_quad: enable\n"
-                               "#extension GL_KHR_shader_subgroup_ballot: enable\n"
-                               + subgroups::getAdditionalExtensionForFormat(caseDef.format) +
+                               + extHeader +
                                "layout(location = 0) out uint result;\n"
                                "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
                                "{\n"
@@ -548,17 +412,9 @@ void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
                                "};\n"
                                "void main (void)\n"
                                "{\n"
-                               "  uvec4 mask = subgroupBallot(true);\n"
-                               + swapTable[caseDef.opType]
+                               "  uint tempRes;\n"
                                + sourceType +
-                               "  if (subgroupBallotBitExtract(mask, otherID))\n"
-                               "  {\n"
-                               "    result = (op == data[otherID]) ? 1 : 0;\n"
-                               "  }\n"
-                               "  else\n"
-                               "  {\n"
-                               "    result = 1; // Invocation we read from was inactive, so we can't verify results!\n"
-                               "  }\n"
+                               "  result = tempRes;\n"
                                "}\n";
                        programCollection.glslSources.add("fragment")
                                << glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);