Simplify code for subgroup builtin mask tests
author: Graeme Leese <gleese@broadcom.com>
Thu, 10 Dec 2020 11:49:36 +0000 (11:49 +0000)
committer: Alexander Galazin <Alexander.Galazin@arm.com>
Fri, 18 Dec 2020 08:10:35 +0000 (08:10 +0000)
A lot of code was being duplicated, which is now shared. The shaders
were calculating bitCount by looping over the bits, which is less
efficient than using the intrinsic.

The new code is shorter and should be easier to understand and faster to
execute, but there is no functional change.

Components: Vulkan
Affects: dEQP-VK.subgroups.builtin_mask_var.*

Change-Id: I6ef0d607423aa5cccce17aefeaac8cc1055ec488

external/vulkancts/modules/vulkan/subgroups/vktSubgroupsBuiltinMaskVarTests.cpp

index 4c1d2f6..92da08e 100755 (executable)
@@ -146,74 +146,32 @@ std::string varSubgroupMask (const CaseDefinition& caseDef)
 
 std::string subgroupMask (const CaseDefinition& caseDef)
 {
+       std::string comp;
+       if (caseDef.varName == "gl_SubgroupEqMask")
+               comp = "==";
+       else if (caseDef.varName == "gl_SubgroupGeMask")
+               comp = ">=";
+       else if (caseDef.varName == "gl_SubgroupGtMask")
+               comp = ">";
+       else if (caseDef.varName == "gl_SubgroupLeMask")
+               comp = "<=";
+       else if (caseDef.varName == "gl_SubgroupLtMask")
+               comp = "<";
+
        std::ostringstream bdy;
 
        bdy << "  uint tempResult = 0x1;\n"
-               << "  uint bit        = 0x1;\n"
-               << "  uint bitCount   = 0x0;\n"
                << "  uvec4 mask = subgroupBallot(true);\n"
                << "  const uvec4 var = " << caseDef.varName << ";\n"
                << "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
-               << "  {\n";
-
-       if ("gl_SubgroupEqMask" == caseDef.varName)
-       {
-               bdy << "    if ((i == gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
-                       << "    {\n"
-                       << "      tempResult = 0;\n"
-                       << "    }\n";
-       }
-       else if ("gl_SubgroupGeMask" == caseDef.varName)
-       {
-               bdy << "    if ((i >= gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
-                       << "    {\n"
-                       << "      tempResult = 0;\n"
-                       << "    }\n";
-       }
-       else if ("gl_SubgroupGtMask" == caseDef.varName)
-       {
-               bdy << "    if ((i > gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
-                       << "    {\n"
-                       << "      tempResult = 0;\n"
-                       << "    }\n";
-       }
-       else if ("gl_SubgroupLeMask" == caseDef.varName)
-       {
-               bdy << "    if ((i <= gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
-                       << "    {\n"
-                       << "      tempResult = 0;\n"
-                       << "    }\n";
-       }
-       else if ("gl_SubgroupLtMask" == caseDef.varName)
-       {
-               bdy << "    if ((i < gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
-                       << "    {\n"
-                       << "      tempResult = 0;\n"
-                       << "    }\n";
-       }
-
-       bdy << "  }\n"
-               << "  for (uint i = 0; i < 32; i++)\n"
                << "  {\n"
-               << "    if ((var.x & bit) > 0)\n"
-               << "    {\n"
-               << "      bitCount++;\n"
-               << "    }\n"
-               << "    if ((var.y & bit) > 0)\n"
-               << "    {\n"
-               << "      bitCount++;\n"
-               << "    }\n"
-               << "    if ((var.z & bit) > 0)\n"
-               << "    {\n"
-               << "      bitCount++;\n"
-               << "    }\n"
-               << "    if ((var.w & bit) > 0)\n"
+               << "    if ((i " << comp << " gl_SubgroupInvocationID) ^^ subgroupBallotBitExtract(var, i))\n"
                << "    {\n"
-               << "      bitCount++;\n"
+               << "      tempResult = 0;\n"
                << "    }\n"
-               << "    bit = bit<<1;\n"
                << "  }\n"
-               << "  if (subgroupBallotBitCount(var) != bitCount)\n"
+               << "  uint c = bitCount(var.x) + bitCount(var.y) + bitCount(var.z) + bitCount(var.w);\n"
+               << "  if (subgroupBallotBitCount(var) != c)\n"
                << "  {\n"
                << "    tempResult = 0;\n"
                << "  }\n";