// limitations under the License.
#include "folding_rules.h"
+#include "latest_version_glsl_std_450_header.h"
namespace spvtools {
namespace opt {
const uint32_t kExtractCompositeIdInIdx = 0;
const uint32_t kInsertObjectIdInIdx = 0;
const uint32_t kInsertCompositeIdInIdx = 1;
+const uint32_t kExtInstSetIdInIdx = 0;
+const uint32_t kExtInstInstructionInIdx = 1;
+const uint32_t kFMixXIdInIdx = 2;
+const uint32_t kFMixYIdInIdx = 3;
FoldingRule IntMultipleBy1() {
return [](ir::Instruction* inst,
}
};
}
+
+enum class FloatConstantKind { Unknown, Zero, One };
+
+FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
+ if (constant == nullptr) {
+ return FloatConstantKind::Unknown;
+ }
+
+ if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) {
+ const std::vector<const analysis::Constant*>& components =
+ vc->GetComponents();
+ assert(!components.empty());
+
+ FloatConstantKind kind = getFloatConstantKind(components[0]);
+
+ for (size_t i = 1; i < components.size(); ++i) {
+ if (getFloatConstantKind(components[i]) != kind) {
+ return FloatConstantKind::Unknown;
+ }
+ }
+
+ return kind;
+ } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
+ double value = (fc->type()->AsFloat()->width() == 64) ? fc->GetDoubleValue()
+ : fc->GetFloatValue();
+
+ if (value == 0.0) {
+ return FloatConstantKind::Zero;
+ } else if (value == 1.0) {
+ return FloatConstantKind::One;
+ } else {
+ return FloatConstantKind::Unknown;
+ }
+ } else {
+ return FloatConstantKind::Unknown;
+ }
+}
+
+FoldingRule RedundantFAdd() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
+ assert(constants.size() == 2);
+
+ if (!inst->IsFloatingPointFoldingAllowed()) {
+ return false;
+ }
+
+ FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+ FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+ if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
+ {inst->GetSingleWordInOperand(
+ kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
+ return true;
+ }
+
+ return false;
+ };
+}
+
+FoldingRule RedundantFSub() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
+ assert(constants.size() == 2);
+
+ if (!inst->IsFloatingPointFoldingAllowed()) {
+ return false;
+ }
+
+ FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+ FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+ if (kind0 == FloatConstantKind::Zero) {
+ inst->SetOpcode(SpvOpFNegate);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
+ return true;
+ }
+
+ if (kind1 == FloatConstantKind::Zero) {
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
+ return true;
+ }
+
+ return false;
+ };
+}
+
+FoldingRule RedundantFMul() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
+ assert(constants.size() == 2);
+
+ if (!inst->IsFloatingPointFoldingAllowed()) {
+ return false;
+ }
+
+ FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+ FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+ if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
+ {inst->GetSingleWordInOperand(
+ kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
+ return true;
+ }
+
+ if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
+ {inst->GetSingleWordInOperand(
+ kind0 == FloatConstantKind::One ? 1 : 0)}}});
+ return true;
+ }
+
+ return false;
+ };
+}
+
+FoldingRule RedundantFDiv() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
+ assert(constants.size() == 2);
+
+ if (!inst->IsFloatingPointFoldingAllowed()) {
+ return false;
+ }
+
+ FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+ FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+ if (kind0 == FloatConstantKind::Zero) {
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
+ return true;
+ }
+
+ if (kind1 == FloatConstantKind::One) {
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
+ return true;
+ }
+
+ return false;
+ };
+}
+
+FoldingRule RedundantFMix() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpExtInst &&
+ "Wrong opcode. Should be OpExtInst.");
+
+ if (!inst->IsFloatingPointFoldingAllowed()) {
+ return false;
+ }
+
+ uint32_t instSetId =
+ inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+
+ if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
+ inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
+ GLSLstd450FMix) {
+ assert(constants.size() == 5);
+
+ FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
+
+ if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID,
+ {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
+ ? kFMixXIdInIdx
+ : kFMixYIdInIdx)}}});
+ return true;
+ }
+ }
+
+ return false;
+ };
+}
+
} // namespace
spvtools::opt::FoldingRules::FoldingRules() {
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
+ rules_[SpvOpExtInst].push_back(RedundantFMix());
+
+ rules_[SpvOpFAdd].push_back(RedundantFAdd());
+ rules_[SpvOpFDiv].push_back(RedundantFDiv());
+ rules_[SpvOpFMul].push_back(RedundantFMul());
+ rules_[SpvOpFSub].push_back(RedundantFSub());
+
rules_[SpvOpIMul].push_back(IntMultipleBy1());
rules_[SpvOpPhi].push_back(RedundantPhi());
rules_[SpvOpSelect].push_back(RedundantSelect());
}
+
} // namespace opt
} // namespace spvtools
#define TRUE_ID 101
#define VEC2_0_ID 102
#define INT_7_ID 103
+#define FLOAT_0_ID 104
+#define DOUBLE_0_ID 105
+#define VEC4_0_ID 106
+#define DVEC4_0_ID 106
const std::string& Header() {
static const std::string header = R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
%uint = OpTypeInt 32 1
%v2int = OpTypeVector %int 2
%v4int = OpTypeVector %int 4
+%v4float = OpTypeVector %float 4
+%v4double = OpTypeVector %double 4
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
%_ptr_int = OpTypePointer Function %int
%_ptr_uint = OpTypePointer Function %uint
%_ptr_bool = OpTypePointer Function %bool
+%_ptr_float = OpTypePointer Function %float
+%_ptr_double = OpTypePointer Function %double
+%_ptr_v4float = OpTypePointer Function %v4float
+%_ptr_v4double = OpTypePointer Function %v4double
%_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int
%short_0 = OpConstant %short 0
%short_3 = OpConstant %short 3
%float16_1 = OpConstant %float16 1
%float16_2 = OpConstant %float16 2
%float_n1 = OpConstant %float -1
+%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
%double_n1 = OpConstant %double -1
+%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
%double_0 = OpConstant %double 0
%double_1 = OpConstant %double 1
%double_2 = OpConstant %double 2
%double_3 = OpConstant %double 3
%float_nan = OpConstant %float -0x1.8p+128
%double_nan = OpConstant %double -0x1.8p+1024
+%106 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%v4float_0_0_0_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%v4float_0_0_0_1 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1
+%v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+%107 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
+%v4double_0_0_0_0 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
+%v4double_0_0_0_1 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_1
+%v4double_1_1_1_1 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_1
)";
return header;
"OpFunctionEnd",
2, INT_0_ID)
));
+
+INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest,
+ ::testing::Values(
+ // Test case 0: Don't fold n + 1.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFAdd %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 1: Don't fold n - 1.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFSub %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 2: Don't fold n * 2.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFMul %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 3: Don't fold n / 2.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFDiv %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 4: Fold n + 0.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFAdd %float %3 %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 5: Fold 0.0 + n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFAdd %float %float_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 6: Fold n - 0.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFSub %float %3 %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 7: Fold n * 1.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFMul %float %3 %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 8: Fold 1.0 * n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFMul %float %float_1 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 9: Fold n / 1.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFDiv %float %3 %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 10: Fold n * 0.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFMul %float %3 %104\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, FLOAT_0_ID),
+ // Test case 11: Fold 0.0 * n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFMul %float %104 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, FLOAT_0_ID),
+ // Test case 12: Fold 0.0 / n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFDiv %float %104 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, FLOAT_0_ID),
+ // Test case 13: Don't fold mix(a, b, 2.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%a = OpVariable %_ptr_float Function\n" +
+ "%b = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %a\n" +
+ "%4 = OpLoad %float %b\n" +
+ "%2 = OpExtInst %float %1 FMix %3 %4 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 14: Fold mix(a, b, 0.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%a = OpVariable %_ptr_float Function\n" +
+ "%b = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %a\n" +
+ "%4 = OpLoad %float %b\n" +
+ "%2 = OpExtInst %float %1 FMix %3 %4 %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 15: Fold mix(a, b, 1.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%a = OpVariable %_ptr_float Function\n" +
+ "%b = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %a\n" +
+ "%4 = OpLoad %float %b\n" +
+ "%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 4)
+));
+
+INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest,
+ ::testing::Values(
+ // Test case 0: Don't fold n + 1.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFAdd %double %3 %double_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 1: Don't fold n - 1.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFSub %double %3 %double_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 2: Don't fold n * 2.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFMul %double %3 %double_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 3: Don't fold n / 2.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFDiv %double %3 %double_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 4: Fold n + 0.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFAdd %double %3 %double_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 5: Fold 0.0 + n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFAdd %double %double_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 6: Fold n - 0.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFSub %double %3 %double_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 7: Fold n * 1.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFMul %double %3 %double_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 8: Fold 1.0 * n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFMul %double %double_1 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 9: Fold n / 1.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFDiv %double %3 %double_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 10: Fold n * 0.0
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFMul %double %3 %105\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, DOUBLE_0_ID),
+ // Test case 11: Fold 0.0 * n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFMul %double %105 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, DOUBLE_0_ID),
+ // Test case 12: Fold 0.0 / n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFDiv %double %105 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, DOUBLE_0_ID),
+ // Test case 13: Don't fold mix(a, b, 2.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%a = OpVariable %_ptr_double Function\n" +
+ "%b = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %a\n" +
+ "%4 = OpLoad %double %b\n" +
+ "%2 = OpExtInst %double %1 FMix %3 %4 %double_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 14: Fold mix(a, b, 0.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%a = OpVariable %_ptr_double Function\n" +
+ "%b = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %a\n" +
+ "%4 = OpLoad %double %b\n" +
+ "%2 = OpExtInst %double %1 FMix %3 %4 %double_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 15: Fold mix(a, b, 1.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%a = OpVariable %_ptr_double Function\n" +
+ "%b = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %a\n" +
+ "%4 = OpLoad %double %b\n" +
+ "%2 = OpExtInst %double %1 FMix %3 %4 %double_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 4)
+));
+
+INSTANTIATE_TEST_CASE_P(FloatVectorRedundantFoldingTest, GeneralInstructionFoldingTest,
+ ::testing::Values(
+ // Test case 0: Don't fold a * vec4(0.0, 0.0, 0.0, 1.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4float Function\n" +
+ "%3 = OpLoad %v4float %n\n" +
+ "%2 = OpFMul %v4float %3 %v4float_0_0_0_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 1: Fold a * vec4(0.0, 0.0, 0.0, 0.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4float Function\n" +
+ "%3 = OpLoad %v4float %n\n" +
+ "%2 = OpFMul %v4float %3 %106\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, VEC4_0_ID),
+ // Test case 2: Fold a * vec4(1.0, 1.0, 1.0, 1.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4float Function\n" +
+ "%3 = OpLoad %v4float %n\n" +
+ "%2 = OpFMul %v4float %3 %v4float_1_1_1_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3)
+));
+
+INSTANTIATE_TEST_CASE_P(DoubleVectorRedundantFoldingTest, GeneralInstructionFoldingTest,
+ ::testing::Values(
+ // Test case 0: Don't fold a * vec4(0.0, 0.0, 0.0, 1.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpLoad %v4double %n\n" +
+ "%2 = OpFMul %v4double %3 %v4double_0_0_0_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 1: Fold a * vec4(0.0, 0.0, 0.0, 0.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpLoad %v4double %n\n" +
+ "%2 = OpFMul %v4double %3 %106\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, DVEC4_0_ID),
+ // Test case 2: Fold a * vec4(1.0, 1.0, 1.0, 1.0)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpLoad %v4double %n\n" +
+ "%2 = OpFMul %v4double %3 %v4double_1_1_1_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3)
+));
+// clang-format on
+
+using ToNegateFoldingTest =
+ ::testing::TestWithParam<InstructionFoldingCase<uint32_t>>;
+
+TEST_P(ToNegateFoldingTest, Case) {
+ const auto& tc = GetParam();
+
+ // Build module.
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, context);
+
+ // Fold the instruction to test.
+ opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+ ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+ std::unique_ptr<ir::Instruction> original_inst(inst->Clone(context.get()));
+ bool succeeded = opt::FoldInstruction(inst);
+
+ // Make sure the instruction folded as expected.
+ EXPECT_EQ(inst->result_id(), original_inst->result_id());
+ EXPECT_EQ(inst->type_id(), original_inst->type_id());
+ EXPECT_TRUE((!succeeded) == (tc.expected_result == 0));
+ if (succeeded) {
+ EXPECT_EQ(inst->opcode(), SpvOpFNegate);
+ EXPECT_EQ(inst->GetSingleWordInOperand(0), tc.expected_result);
+ } else {
+ EXPECT_EQ(inst->NumInOperands(), original_inst->NumInOperands());
+ for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
+ EXPECT_EQ(inst->GetOperand(i), original_inst->GetOperand(i));
+ }
+ }
+}
+
// clang-format off
+INSTANTIATE_TEST_CASE_P(FloatRedundantSubFoldingTest, ToNegateFoldingTest,
+ ::testing::Values(
+ // Test case 0: Don't fold 1.0 - n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFSub %float %float_1 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 1: Fold 0.0 - n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_float Function\n" +
+ "%3 = OpLoad %float %n\n" +
+ "%2 = OpFSub %float %float_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 2: Don't fold (0,0,0,1) - n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4float Function\n" +
+ "%3 = OpLoad %v4float %n\n" +
+ "%2 = OpFSub %v4float %v4float_0_0_0_1 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 3: Fold (0,0,0,0) - n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4float Function\n" +
+ "%3 = OpLoad %v4float %n\n" +
+ "%2 = OpFSub %v4float %v4float_0_0_0_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3)
+));
+
+INSTANTIATE_TEST_CASE_P(DoubleRedundantSubFoldingTest, ToNegateFoldingTest,
+ ::testing::Values(
+ // Test case 0: Don't fold 1.0 - n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFSub %double %double_1 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 1: Fold 0.0 - n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_double Function\n" +
+ "%3 = OpLoad %double %n\n" +
+ "%2 = OpFSub %double %double_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 2: Don't fold (0,0,0,1) - n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpLoad %v4double %n\n" +
+ "%2 = OpFSub %v4double %v4double_0_0_0_1 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 3: Fold (0,0,0,0) - n
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpLoad %v4double %n\n" +
+ "%2 = OpFSub %v4double %v4double_0_0_0_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3)
+));
+// clang-format on
+
} // anonymous namespace