<< " uint sgInvocation = gl_SubgroupInvocationID;\n";
}
- if (OPTYPE_BROADCAST == caseDef.opType)
+ const std::string fmt = subgroups::getFormatNameForGLSL(caseDef.format);
+
+ if (caseDef.opType == OPTYPE_BROADCAST)
{
bdy << " uint tempResult = 0x3;\n";
for (int i = 0; i < max; i++)
{
bdy << " {\n"
<< " const uint id = "<< i << ";\n"
- << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
- << broadcast << "(data1[sgInvocation], id);\n"
+ << " " << fmt << " op = " << broadcast << "(data1[sgInvocation], id);\n"
<< " if ((id < sgSize) && subgroupBallotBitExtract(mask, id))\n"
<< " {\n"
<< " if (op != data1[id])\n"
<< " }\n";
}
}
- else if (OPTYPE_BROADCAST_NONCONST == caseDef.opType)
+ else if (caseDef.opType == OPTYPE_BROADCAST_NONCONST)
{
+ const std::string validate = " if (subgroupBallotBitExtract(mask, id) && op != data1[id])\n"
+ " tempResult = 0;\n";
+
bdy << " uint tempResult = 0x3;\n"
<< " for (uint id = 0; id < sgSize; id++)\n"
<< " {\n"
- << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
- << broadcast << "(data1[sgInvocation], id);\n"
- << " if (subgroupBallotBitExtract(mask, id))\n"
- << " {\n"
- << " if (op != data1[id])\n"
- << " {\n"
- << " tempResult = 0;\n"
- << " }\n"
- << " }\n"
+ << " " << fmt << " op = " << broadcast << "(data1[sgInvocation], id);\n"
+ << validate
+ << " }\n"
+ << " // Test lane id that is only uniform across active lanes\n"
+ << " if (sgInvocation >= sgSize / 2)\n"
+ << " {\n"
+ << " uint id = sgInvocation & ~((sgSize / 2) - 1);\n"
+ << " " << fmt << " op = " << broadcast << "(data1[sgInvocation], id);\n"
+ << validate
<< " }\n";
}
else