Add folding for redundant add/sub/mul/div/mix operations
authorArseny Kapoulkine <arseny.kapoulkine@gmail.com>
Sat, 17 Feb 2018 19:55:54 +0000 (11:55 -0800)
committerSteven Perron <stevenperron@google.com>
Tue, 20 Feb 2018 23:29:27 +0000 (18:29 -0500)
This change implements instruction folding for arithmetic operations
that are redundant, specifically:

  x + 0 = 0 + x = x
  x - 0 = x
  0 - x = -x
  x * 0 = 0 * x = 0
  x * 1 = 1 * x = x
  0 / x = 0
  x / 1 = x
  mix(a, b, 0) = a
  mix(a, b, 1) = b

Cache ExtInst import id in feature manager

This allows us to avoid string lookups during optimization; for now we
just cache GLSL std450 import id but I can imagine caching more sets as
they become utilized by the optimizer.

Add tests for add/sub/mul/div/mix folding

The tests cover scalar float/double cases, and some vector cases.

Since most of the code for floating point folding is shared, the tests
for vector folding are not as exhaustive as scalar.

To test sub->negate folding I had to implement a custom fixture.

source/opt/const_folding_rules.cpp
source/opt/feature_manager.cpp
source/opt/feature_manager.h
source/opt/folding_rules.cpp
source/opt/insert_extract_elim.cpp
source/opt/instruction.cpp
source/opt/instruction.h
source/opt/pass.h
test/opt/fold_test.cpp

index ab39ed8..6556264 100644 (file)
@@ -30,23 +30,6 @@ inline std::vector<uint32_t> ExtractInts(uint64_t a) {
   return result;
 }
 
-// Returns true if we are allowed to fold or otherwise manipulate the
-// instruction that defines |id| in the given context.
-bool CanFoldFloatingPoint(ir::IRContext* context, uint32_t id) {
-  // TODO: Add the rules for kernels.  For now it will be pessimistic.
-  if (!context->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
-    return false;
-  }
-
-  bool is_nocontract = false;
-  context->get_decoration_mgr()->WhileEachDecoration(
-      id, SpvDecorationNoContraction, [&is_nocontract](const ir::Instruction&) {
-        is_nocontract = true;
-        return false;
-      });
-  return !is_nocontract;
-}
-
 // Folds an OpcompositeExtract where input is a composite constant.
 ConstantFoldingRule FoldExtractWithConstants() {
   return [](ir::Instruction* inst,
@@ -147,7 +130,7 @@ ConstantFoldingRule FoldFloatingPointOp(FloatScalarFoldingRule scalar_rule) {
     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
     const analysis::Vector* vector_type = result_type->AsVector();
 
-    if (!CanFoldFloatingPoint(context, inst->result_id())) {
+    if (!inst->IsFloatingPointFoldingAllowed()) {
       return nullptr;
     }
 
index 8e1bcdc..f9b91bd 100644 (file)
@@ -24,6 +24,7 @@ namespace opt {
 void FeatureManager::Analyze(ir::Module* module) {
   AddExtensions(module);
   AddCapabilities(module);
+  AddExtInstImportIds(module);
 }
 
 void FeatureManager::AddExtensions(ir::Module* module) {
@@ -56,5 +57,9 @@ void FeatureManager::AddCapabilities(ir::Module* module) {
   }
 }
 
+void FeatureManager::AddExtInstImportIds(ir::Module* module) {
+  extinst_importid_GLSLstd450_ = module->GetExtInstImportId("GLSL.std.450");
+}
+
 }  // namespace opt
 }  // namespace spvtools
index 9c2a05c..b99a776 100644 (file)
@@ -46,6 +46,10 @@ class FeatureManager {
     return &capabilities_;
   }
 
+  uint32_t GetExtInstImportId_GLSLstd450() const {
+    return extinst_importid_GLSLstd450_;
+  }
+
  private:
   // Analyzes |module| and records enabled extensions.
   void AddExtensions(ir::Module* module);
@@ -57,6 +61,9 @@ class FeatureManager {
   // Analyzes |module| and records enabled capabilities.
   void AddCapabilities(ir::Module* module);
 
+  // Analyzes |module| and records imported external instruction sets.
+  void AddExtInstImportIds(ir::Module* module);
+
   // Auxiliary object for querying SPIR-V grammar facts.
   const libspirv::AssemblyGrammar& grammar_;
 
@@ -65,6 +72,9 @@ class FeatureManager {
 
   // The enabled capabilities.
   libspirv::CapabilitySet capabilities_;
+
+  // Common external instruction import ids, cached for performance.
+  uint32_t extinst_importid_GLSLstd450_ = 0;
 };
 
 }  // namespace opt
index 681c070..b0f99b7 100644 (file)
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "folding_rules.h"
+#include "latest_version_glsl_std_450_header.h"
 
 namespace spvtools {
 namespace opt {
@@ -21,6 +22,10 @@ namespace {
 const uint32_t kExtractCompositeIdInIdx = 0;
 const uint32_t kInsertObjectIdInIdx = 0;
 const uint32_t kInsertCompositeIdInIdx = 1;
+const uint32_t kExtInstSetIdInIdx = 0;
+const uint32_t kExtInstInstructionInIdx = 1;
+const uint32_t kFMixXIdInIdx = 2;
+const uint32_t kFMixYIdInIdx = 3;
 
 FoldingRule IntMultipleBy1() {
   return [](ir::Instruction* inst,
@@ -326,6 +331,199 @@ FoldingRule RedundantSelect() {
     }
   };
 }
+
+enum class FloatConstantKind { Unknown, Zero, One };
+
+FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
+  if (constant == nullptr) {
+    return FloatConstantKind::Unknown;
+  }
+
+  if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) {
+    const std::vector<const analysis::Constant*>& components =
+        vc->GetComponents();
+    assert(!components.empty());
+
+    FloatConstantKind kind = getFloatConstantKind(components[0]);
+
+    for (size_t i = 1; i < components.size(); ++i) {
+      if (getFloatConstantKind(components[i]) != kind) {
+        return FloatConstantKind::Unknown;
+      }
+    }
+
+    return kind;
+  } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
+    double value = (fc->type()->AsFloat()->width() == 64) ? fc->GetDoubleValue()
+                                                          : fc->GetFloatValue();
+
+    if (value == 0.0) {
+      return FloatConstantKind::Zero;
+    } else if (value == 1.0) {
+      return FloatConstantKind::One;
+    } else {
+      return FloatConstantKind::Unknown;
+    }
+  } else {
+    return FloatConstantKind::Unknown;
+  }
+}
+
+FoldingRule RedundantFAdd() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFAdd && "Wrong opcode.  Should be OpFAdd.");
+    assert(constants.size() == 2);
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+    FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+    if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
+                            {inst->GetSingleWordInOperand(
+                                kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+FoldingRule RedundantFSub() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFSub && "Wrong opcode.  Should be OpFSub.");
+    assert(constants.size() == 2);
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+    FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+    if (kind0 == FloatConstantKind::Zero) {
+      inst->SetOpcode(SpvOpFNegate);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
+      return true;
+    }
+
+    if (kind1 == FloatConstantKind::Zero) {
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+FoldingRule RedundantFMul() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFMul && "Wrong opcode.  Should be OpFMul.");
+    assert(constants.size() == 2);
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+    FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+    if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
+                            {inst->GetSingleWordInOperand(
+                                kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
+      return true;
+    }
+
+    if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
+                            {inst->GetSingleWordInOperand(
+                                kind0 == FloatConstantKind::One ? 1 : 0)}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+FoldingRule RedundantFDiv() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFDiv && "Wrong opcode.  Should be OpFDiv.");
+    assert(constants.size() == 2);
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+    FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+    if (kind0 == FloatConstantKind::Zero) {
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
+      return true;
+    }
+
+    if (kind1 == FloatConstantKind::One) {
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+FoldingRule RedundantFMix() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpExtInst &&
+           "Wrong opcode.  Should be OpExtInst.");
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    uint32_t instSetId =
+        inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+
+    if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
+        inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
+            GLSLstd450FMix) {
+      assert(constants.size() == 5);
+
+      FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
+
+      if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
+        inst->SetOpcode(SpvOpCopyObject);
+        inst->SetInOperands(
+            {{SPV_OPERAND_TYPE_ID,
+              {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
+                                                ? kFMixXIdInIdx
+                                                : kFMixYIdInIdx)}}});
+        return true;
+      }
+    }
+
+    return false;
+  };
+}
+
 }  // namespace
 
 spvtools::opt::FoldingRules::FoldingRules() {
@@ -339,11 +537,19 @@ spvtools::opt::FoldingRules::FoldingRules() {
   rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
   rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
 
+  rules_[SpvOpExtInst].push_back(RedundantFMix());
+
+  rules_[SpvOpFAdd].push_back(RedundantFAdd());
+  rules_[SpvOpFDiv].push_back(RedundantFDiv());
+  rules_[SpvOpFMul].push_back(RedundantFMul());
+  rules_[SpvOpFSub].push_back(RedundantFSub());
+
   rules_[SpvOpIMul].push_back(IntMultipleBy1());
 
   rules_[SpvOpPhi].push_back(RedundantPhi());
 
   rules_[SpvOpSelect].push_back(RedundantSelect());
 }
+
 }  // namespace opt
 }  // namespace spvtools
index d56ca19..a49acca 100644 (file)
@@ -96,7 +96,7 @@ uint32_t InsertExtractElimPass::DoExtract(ir::Instruction* compInst,
       }
     } else if (cinst->opcode() == SpvOpExtInst &&
                cinst->GetSingleWordInOperand(kExtInstSetIdInIdx) ==
-                   get_module()->GetExtInstImportId("GLSL.std.450") &&
+                   get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
                cinst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
                    GLSLstd450FMix) {
       // If mixing value component is 0 or 1 we just match with x or y.
index 9510ee9..f1e483b 100644 (file)
@@ -483,6 +483,22 @@ bool Instruction::IsFoldableByFoldScalar() const {
   return opt::IsFoldableType(type);
 }
 
+bool Instruction::IsFloatingPointFoldingAllowed() const {
+  // TODO: Add the rules for kernels.  For now it will be pessimistic.
+  if (!context_->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
+    return false;
+  }
+
+  bool is_nocontract = false;
+  context_->get_decoration_mgr()->WhileEachDecoration(
+      opcode_, SpvDecorationNoContraction,
+      [&is_nocontract](const ir::Instruction&) {
+        is_nocontract = true;
+        return false;
+      });
+  return !is_nocontract;
+}
+
 std::string Instruction::PrettyPrint(uint32_t options) const {
   // Convert the module to binary.
   std::vector<uint32_t> module_binary;
index fccd4c4..52f58ca 100644 (file)
@@ -372,6 +372,11 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   // constant value by |FoldScalar|.
   bool IsFoldableByFoldScalar() const;
 
+  // Returns true if we are allowed to fold or otherwise manipulate the
+  // instruction that defines |id| in the given context. This includes not
+  // handling NaN values.
+  bool IsFloatingPointFoldingAllowed() const;
+
   inline bool operator==(const Instruction&) const;
   inline bool operator!=(const Instruction&) const;
   inline bool operator<(const Instruction&) const;
index 3733ba5..226968a 100644 (file)
@@ -82,6 +82,10 @@ class Pass {
     return context()->get_decoration_mgr();
   }
 
+  FeatureManager* get_feature_mgr() const {
+    return context()->get_feature_mgr();
+  }
+
   // Returns a pointer to the current module for this pass.
   ir::Module* get_module() const { return context_->module(); }
 
index 5c2053d..8cb7dcd 100644 (file)
@@ -80,6 +80,10 @@ TEST_P(IntegerInstructionFoldingTest, Case) {
 #define TRUE_ID 101
 #define VEC2_0_ID 102
 #define INT_7_ID 103
+#define FLOAT_0_ID 104
+#define DOUBLE_0_ID 105
+#define VEC4_0_ID 106
+#define DVEC4_0_ID 106
 const std::string& Header() {
   static const std::string header = R"(OpCapability Shader
 %1 = OpExtInstImport "GLSL.std.450"
@@ -103,10 +107,16 @@ OpName %main "main"
 %uint = OpTypeInt 32 1
 %v2int = OpTypeVector %int 2
 %v4int = OpTypeVector %int 4
+%v4float = OpTypeVector %float 4
+%v4double = OpTypeVector %double 4
 %struct_v2int_int_int = OpTypeStruct %v2int %int %int
 %_ptr_int = OpTypePointer Function %int
 %_ptr_uint = OpTypePointer Function %uint
 %_ptr_bool = OpTypePointer Function %bool
+%_ptr_float = OpTypePointer Function %float
+%_ptr_double = OpTypePointer Function %double
+%_ptr_v4float = OpTypePointer Function %v4float
+%_ptr_v4double = OpTypePointer Function %v4double
 %_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int
 %short_0 = OpConstant %short 0
 %short_3 = OpConstant %short 3
@@ -132,17 +142,27 @@ OpName %main "main"
 %float16_1 = OpConstant %float16 1
 %float16_2 = OpConstant %float16 2
 %float_n1 = OpConstant %float -1
+%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
 %float_0 = OpConstant %float 0
 %float_1 = OpConstant %float 1
 %float_2 = OpConstant %float 2
 %float_3 = OpConstant %float 3
 %double_n1 = OpConstant %double -1
+%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
 %double_0 = OpConstant %double 0
 %double_1 = OpConstant %double 1
 %double_2 = OpConstant %double 2
 %double_3 = OpConstant %double 3
 %float_nan = OpConstant %float -0x1.8p+128
 %double_nan = OpConstant %double -0x1.8p+1024
+%106 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%v4float_0_0_0_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%v4float_0_0_0_1 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1
+%v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+%107 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
+%v4double_0_0_0_0 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
+%v4double_0_0_0_1 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_1
+%v4double_1_1_1_1 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_1
 )";
 
   return header;
@@ -2211,5 +2231,537 @@ INSTANTIATE_TEST_CASE_P(SelectFoldingTest, GeneralInstructionFoldingTest,
           "OpFunctionEnd",
       2, INT_0_ID)
 ));
+
+INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest,
+                        ::testing::Values(
+    // Test case 0: Don't fold n + 1.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFAdd %float %3 %float_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 1: Don't fold n - 1.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFSub %float %3 %float_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 2: Don't fold n * 2.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFMul %float %3 %float_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 3: Don't fold n / 2.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFDiv %float %3 %float_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 4: Fold n + 0.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFAdd %float %3 %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 5: Fold 0.0 + n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFAdd %float %float_0 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 6: Fold n - 0.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFSub %float %3 %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 7: Fold n * 1.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFMul %float %3 %float_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 8: Fold 1.0 * n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFMul %float %float_1 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 9: Fold n / 1.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFDiv %float %3 %float_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 10: Fold n * 0.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFMul %float %3 %104\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, FLOAT_0_ID),
+    // Test case 11: Fold 0.0 * n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFMul %float %104 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, FLOAT_0_ID),
+    // Test case 12: Fold 0.0 / n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFDiv %float %104 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, FLOAT_0_ID),
+    // Test case 13: Don't fold mix(a, b, 2.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%a = OpVariable %_ptr_float Function\n" +
+            "%b = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %a\n" +
+            "%4 = OpLoad %float %b\n" +
+            "%2 = OpExtInst %float %1 FMix %3 %4 %float_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 14: Fold mix(a, b, 0.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%a = OpVariable %_ptr_float Function\n" +
+            "%b = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %a\n" +
+            "%4 = OpLoad %float %b\n" +
+            "%2 = OpExtInst %float %1 FMix %3 %4 %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 15: Fold mix(a, b, 1.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%a = OpVariable %_ptr_float Function\n" +
+            "%b = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %a\n" +
+            "%4 = OpLoad %float %b\n" +
+            "%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 4)
+));
+
+INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest,
+                        ::testing::Values(
+    // Test case 0: Don't fold n + 1.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFAdd %double %3 %double_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 1: Don't fold n - 1.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFSub %double %3 %double_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 2: Don't fold n * 2.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFMul %double %3 %double_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 3: Don't fold n / 2.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFDiv %double %3 %double_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 4: Fold n + 0.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFAdd %double %3 %double_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 5: Fold 0.0 + n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFAdd %double %double_0 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 6: Fold n - 0.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFSub %double %3 %double_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 7: Fold n * 1.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFMul %double %3 %double_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 8: Fold 1.0 * n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFMul %double %double_1 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 9: Fold n / 1.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFDiv %double %3 %double_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 10: Fold n * 0.0
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFMul %double %3 %105\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, DOUBLE_0_ID),
+    // Test case 11: Fold 0.0 * n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFMul %double %105 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, DOUBLE_0_ID),
+    // Test case 12: Fold 0.0 / n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFDiv %double %105 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, DOUBLE_0_ID),
+    // Test case 13: Don't fold mix(a, b, 2.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%a = OpVariable %_ptr_double Function\n" +
+            "%b = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %a\n" +
+            "%4 = OpLoad %double %b\n" +
+            "%2 = OpExtInst %double %1 FMix %3 %4 %double_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 14: Fold mix(a, b, 0.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%a = OpVariable %_ptr_double Function\n" +
+            "%b = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %a\n" +
+            "%4 = OpLoad %double %b\n" +
+            "%2 = OpExtInst %double %1 FMix %3 %4 %double_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+    // Test case 15: Fold mix(a, b, 1.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%a = OpVariable %_ptr_double Function\n" +
+            "%b = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %a\n" +
+            "%4 = OpLoad %double %b\n" +
+            "%2 = OpExtInst %double %1 FMix %3 %4 %double_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 4)
+));
+
+INSTANTIATE_TEST_CASE_P(FloatVectorRedundantFoldingTest, GeneralInstructionFoldingTest,
+                        ::testing::Values(
+    // Test case 0: Don't fold a * vec4(0.0, 0.0, 0.0, 1.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_v4float Function\n" +
+            "%3 = OpLoad %v4float %n\n" +
+            "%2 = OpFMul %v4float %3 %v4float_0_0_0_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 1: Fold a * vec4(0.0, 0.0, 0.0, 0.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_v4float Function\n" +
+            "%3 = OpLoad %v4float %n\n" +
+            "%2 = OpFMul %v4float %3 %106\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, VEC4_0_ID),
+    // Test case 2: Fold a * vec4(1.0, 1.0, 1.0, 1.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_v4float Function\n" +
+            "%3 = OpLoad %v4float %n\n" +
+            "%2 = OpFMul %v4float %3 %v4float_1_1_1_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3)
+));
+
+INSTANTIATE_TEST_CASE_P(DoubleVectorRedundantFoldingTest, GeneralInstructionFoldingTest,
+                        ::testing::Values(
+    // Test case 0: Don't fold a * vec4(0.0, 0.0, 0.0, 1.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_v4double Function\n" +
+            "%3 = OpLoad %v4double %n\n" +
+            "%2 = OpFMul %v4double %3 %v4double_0_0_0_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 1: Fold a * vec4(0.0, 0.0, 0.0, 0.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_v4double Function\n" +
+            "%3 = OpLoad %v4double %n\n" +
+            "%2 = OpFMul %v4double %3 %106\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, DVEC4_0_ID),
+    // Test case 2: Fold a * vec4(1.0, 1.0, 1.0, 1.0)
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_v4double Function\n" +
+            "%3 = OpLoad %v4double %n\n" +
+            "%2 = OpFMul %v4double %3 %v4double_1_1_1_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3)
+));
+// clang-format on
+
+using ToNegateFoldingTest =
+    ::testing::TestWithParam<InstructionFoldingCase<uint32_t>>;
+
+TEST_P(ToNegateFoldingTest, Case) {
+  const auto& tc = GetParam();
+
+  // Build module.
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ASSERT_NE(nullptr, context);
+
+  // Fold the instruction to test.
+  opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+  ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+  std::unique_ptr<ir::Instruction> original_inst(inst->Clone(context.get()));
+  bool succeeded = opt::FoldInstruction(inst);
+
+  // Make sure the instruction folded as expected.
+  EXPECT_EQ(inst->result_id(), original_inst->result_id());
+  EXPECT_EQ(inst->type_id(), original_inst->type_id());
+  EXPECT_TRUE((!succeeded) == (tc.expected_result == 0));
+  if (succeeded) {
+    EXPECT_EQ(inst->opcode(), SpvOpFNegate);
+    EXPECT_EQ(inst->GetSingleWordInOperand(0), tc.expected_result);
+  } else {
+    EXPECT_EQ(inst->NumInOperands(), original_inst->NumInOperands());
+    for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
+      EXPECT_EQ(inst->GetOperand(i), original_inst->GetOperand(i));
+    }
+  }
+}
+
 // clang-format off
+INSTANTIATE_TEST_CASE_P(FloatRedundantSubFoldingTest, ToNegateFoldingTest,
+                        ::testing::Values(
+    // Test case 0: Don't fold 1.0 - n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFSub %float %float_1 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 1: Fold 0.0 - n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_float Function\n" +
+            "%3 = OpLoad %float %n\n" +
+            "%2 = OpFSub %float %float_0 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+       // Test case 2: Don't fold (0,0,0,1) - n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_v4float Function\n" +
+            "%3 = OpLoad %v4float %n\n" +
+            "%2 = OpFSub %v4float %v4float_0_0_0_1 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+       // Test case 3: Fold (0,0,0,0) - n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_v4float Function\n" +
+            "%3 = OpLoad %v4float %n\n" +
+            "%2 = OpFSub %v4float %v4float_0_0_0_0 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3)
+));
+
+INSTANTIATE_TEST_CASE_P(DoubleRedundantSubFoldingTest, ToNegateFoldingTest,
+                        ::testing::Values(
+    // Test case 0: Don't fold 1.0 - n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFSub %double %double_1 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+    // Test case 1: Fold 0.0 - n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_double Function\n" +
+            "%3 = OpLoad %double %n\n" +
+            "%2 = OpFSub %double %double_0 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3),
+       // Test case 2: Don't fold (0,0,0,1) - n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_v4double Function\n" +
+            "%3 = OpLoad %v4double %n\n" +
+            "%2 = OpFSub %v4double %v4double_0_0_0_1 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0),
+       // Test case 3: Fold (0,0,0,0) - n
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%n = OpVariable %_ptr_v4double Function\n" +
+            "%3 = OpLoad %v4double %n\n" +
+            "%2 = OpFSub %v4double %v4double_0_0_0_0 %3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3)
+));
+// clang-format on
+
 }  // anonymous namespace