Combine ARB and Core subgroupBroadcast code
authorGraeme Leese <gleese@broadcom.com>
Mon, 2 Sep 2019 16:49:51 +0000 (17:49 +0100)
committerAlexander Galazin <Alexander.Galazin@arm.com>
Wed, 18 Sep 2019 15:32:11 +0000 (11:32 -0400)
The ARB and core tests are attempting to do the same thing but using
slightly different names for the functions. Parameterise the names so
that we only have to maintain one copy of the test logic.

Also, make it so that we only have to write out the extensions that we
need once.

Components: Vulkan
Affects: dEQP-VK.subgroups.ballot_broadcast.*

Change-Id: I259249abba2b352dbd217aa4c26100ab7dbee637

external/vulkancts/modules/vulkan/subgroups/vktSubgroupsBallotBroadcastTests.cpp

index 1e168ff..3575141 100755 (executable)
@@ -83,79 +83,40 @@ std::string getBodySource(CaseDefinition caseDef)
 {
        std::ostringstream bdy;
 
-       bdy << "  uvec4 mask = subgroupBallot(true);\n";
-       bdy << "  uint tempResult = 0;\n";
-
-       if (OPTYPE_BROADCAST == caseDef.opType)
+       std::string broadcast;
+       std::string broadcastFirst;
+       int max;
+       if (caseDef.extShaderSubGroupBallotTests)
        {
-               bdy     << "  tempResult = 0x3;\n";
-               for (int i = 0; i < (int)subgroups::maxSupportedSubgroupSize(); i++)
-               {
-                       bdy << "  {\n"
-                       << "    const uint id = "<< i << ";\n"
-                       << "    " << subgroups::getFormatNameForGLSL(caseDef.format)
-                       << " op = subgroupBroadcast(data1[gl_SubgroupInvocationID], id);\n"
-                       << "    if ((id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
-                       << "    {\n"
-                       << "      if (op != data1[id])\n"
-                       << "      {\n"
-                       << "        tempResult = 0;\n"
-                       << "      }\n"
-                       << "    }\n"
-                       << "  }\n";
-               }
+               broadcast               = "readInvocationARB";
+               broadcastFirst  = "readFirstInvocationARB";
+               max = 64;
+
+               bdy << "  uint64_t mask = ballotARB(true);\n";
+               bdy << "  uint sgSize = gl_SubGroupSizeARB;\n";
+               bdy << "  uint sgInvocation = gl_SubGroupInvocationARB;\n";
        }
        else
        {
-               bdy     << "  uint firstActive = 0;\n"
-                       << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
-                       << "  {\n"
-                       << "    if (subgroupBallotBitExtract(mask, i))\n"
-                       << "    {\n"
-                       << "      firstActive = i;\n"
-                       << "      break;\n"
-                       << "    }\n"
-                       << "  }\n"
-                       << "  tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x1 : 0;\n"
-                       << "  // make the firstActive invocation inactive now\n"
-                       << "  if (firstActive == gl_SubgroupInvocationID)\n"
-                       << "  {\n"
-                       << "    for (uint i = 0; i < gl_SubgroupSize; i++)\n"
-                       << "    {\n"
-                       << "      if (subgroupBallotBitExtract(mask, i))\n"
-                       << "      {\n"
-                       << "        firstActive = i;\n"
-                       << "        break;\n"
-                       << "      }\n"
-                       << "    }\n"
-                       << "    tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x2 : 0;\n"
-                       << "  }\n"
-                       << "  else\n"
-                       << "  {\n"
-                       << "    // the firstActive invocation didn't partake in the second result so set it to true\n"
-                       << "    tempResult |= 0x2;\n"
-                       << "  }\n";
-       }
-   return bdy.str();
-}
+               broadcast               = "subgroupBroadcast";
+               broadcastFirst  = "subgroupBroadcastFirst";
+               max = (int)subgroups::maxSupportedSubgroupSize();
 
-std::string getBodySourceARB(CaseDefinition caseDef)
-{
-       std::ostringstream bdy;
-
-       bdy << "  uint64_t mask = ballotARB(true);\n";
-       bdy << "  uint tempResult = 0;\n";
+               bdy << "  uvec4 mask = subgroupBallot(true);\n";
+               bdy << "  uint sgSize = gl_SubgroupSize;\n";
+               bdy << "  uint sgInvocation = gl_SubgroupInvocationID;\n";
+       }
 
        if (OPTYPE_BROADCAST == caseDef.opType)
        {
-               bdy     << "  tempResult = 0x3;\n";
-               for (int i = 0; i < 64; i++)
+               bdy     << "  uint tempResult = 0x3;\n";
+               for (int i = 0; i < max; i++)
                {
                        bdy << "  {\n"
                        << "    const uint id = "<< i << ";\n"
-                       << "    " << subgroups::getFormatNameForGLSL(caseDef.format)
-                       << " op = readInvocationARB(data1[gl_SubGroupInvocationARB], id);\n"
-                       << "    if ((id < gl_SubGroupSizeARB) && subgroupBallotBitExtract(mask, id))\n"
+                       << "    " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
+                               << broadcast << "(data1[sgInvocation], id);\n"
+                       << "    if ((id < sgSize) && subgroupBallotBitExtract(mask, id))\n"
                        << "    {\n"
                        << "      if (op != data1[id])\n"
                        << "      {\n"
@@ -167,8 +128,9 @@ std::string getBodySourceARB(CaseDefinition caseDef)
        }
        else
        {
-               bdy     << "  uint firstActive = 0;\n"
-                       << "  for (uint i = 0; i < gl_SubGroupSizeARB; i++)\n"
+               bdy << "  uint tempResult = 0;\n"
+                       << "  uint firstActive = 0;\n"
+                       << "  for (uint i = 0; i < sgSize; i++)\n"
                        << "  {\n"
                        << "    if (subgroupBallotBitExtract(mask, i))\n"
                        << "    {\n"
@@ -176,11 +138,11 @@ std::string getBodySourceARB(CaseDefinition caseDef)
                        << "      break;\n"
                        << "    }\n"
                        << "  }\n"
-                       << "  tempResult |= (readFirstInvocationARB(data1[gl_SubGroupInvocationARB]) == data1[firstActive]) ? 0x1 : 0;\n"
+                       << "  tempResult |= (" << broadcastFirst << "(data1[sgInvocation]) == data1[firstActive]) ? 0x1 : 0;\n"
                        << "  // make the firstActive invocation inactive now\n"
-                       << "  if (firstActive == gl_SubGroupInvocationARB)\n"
+                       << "  if (firstActive == sgInvocation)\n"
                        << "  {\n"
-                       << "    for (uint i = 0; i < gl_SubGroupSizeARB; i++)\n"
+                       << "    for (uint i = 0; i < sgSize; i++)\n"
                        << "    {\n"
                        << "      if (subgroupBallotBitExtract(mask, i))\n"
                        << "      {\n"
@@ -188,7 +150,7 @@ std::string getBodySourceARB(CaseDefinition caseDef)
                        << "        break;\n"
                        << "      }\n"
                        << "    }\n"
-                       << "    tempResult |= (readFirstInvocationARB(data1[gl_SubGroupInvocationARB]) == data1[firstActive]) ? 0x2 : 0;\n"
+                       << "    tempResult |= (" << broadcastFirst << "(data1[sgInvocation]) == data1[firstActive]) ? 0x2 : 0;\n"
                        << "  }\n"
                        << "  else\n"
                        << "  {\n"
@@ -221,14 +183,17 @@ std::string getHelperFunctionARB(CaseDefinition caseDef)
 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
 {
        const vk::ShaderBuildOptions    buildOptions    (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
-       const string extensionHeader =  (caseDef.extShaderSubGroupBallotTests ? "#extension GL_ARB_shader_ballot: enable\n#extension GL_KHR_shader_subgroup_basic: enable\n#extension GL_ARB_gpu_shader_int64: enable\n" : "#extension GL_KHR_shader_subgroup_ballot: enable\n");
+       const string extensionHeader =  (caseDef.extShaderSubGroupBallotTests ? "#extension GL_ARB_shader_ballot: enable\n"
+                                                                                                                                                       "#extension GL_KHR_shader_subgroup_basic: enable\n"
+                                                                                                                                                       "#extension GL_ARB_gpu_shader_int64: enable\n"
+                                                                                                                                               :       "#extension GL_KHR_shader_subgroup_ballot: enable\n");
 
        subgroups::setFragmentShaderFrameBuffer(programCollection);
 
        if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
                subgroups::setVertexShaderFrameBuffer(programCollection);
 
-       std::string bdyStr = (caseDef.extShaderSubGroupBallotTests ? getBodySourceARB(caseDef) : getBodySource(caseDef));
+       std::string bdyStr = getBodySource(caseDef);
        std::string helperStrARB = getHelperFunctionARB(caseDef);
 
        if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
@@ -344,10 +309,13 @@ void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinitio
 
 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
 {
-       std::string bdyStr = caseDef.extShaderSubGroupBallotTests ? getBodySourceARB(caseDef) : getBodySource(caseDef);
+       std::string bdyStr = getBodySource(caseDef);
        std::string helperStrARB = getHelperFunctionARB(caseDef);
 
-       const string extensionHeader =  (caseDef.extShaderSubGroupBallotTests ? "#extension GL_ARB_shader_ballot: enable\n#extension GL_KHR_shader_subgroup_basic: enable\n#extension GL_ARB_gpu_shader_int64: enable\n" : "#extension GL_KHR_shader_subgroup_ballot: enable\n");
+       const string extensionHeader =  (caseDef.extShaderSubGroupBallotTests ? "#extension GL_ARB_shader_ballot: enable\n"
+                                                                                                                                                       "#extension GL_KHR_shader_subgroup_basic: enable\n"
+                                                                                                                                                       "#extension GL_ARB_gpu_shader_int64: enable\n"
+                                                                                                                                               :       "#extension GL_KHR_shader_subgroup_ballot: enable\n");
 
        if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
        {