Merge arithmetic with non-trivial constant operands
authorAlan Baker <alanbaker@google.com>
Fri, 16 Feb 2018 21:07:33 +0000 (16:07 -0500)
committerDavid Neto <dneto@google.com>
Tue, 27 Feb 2018 18:02:13 +0000 (13:02 -0500)
Adding basis of arithmetic merging

* Refactored constant collection in ConstantManager
* New rules:
 * consecutive negates
 * negate of arithmetic op with a constant
 * consecutive muls
 * reciprocal of div

* Removed IRContext::CanFoldFloatingPoint
 * replaced by Instruction::IsFloatingPointFoldingAllowed
* Fixed some bad tests
* added some header comments

Added PerformIntegerOperation

* minor fixes to constants and tests
* fixed IntMultiplyBy1 to work with 64 bit ints
* added tests for integer mul merging

Adding test for vector integer multiply merging

Adding support for merging integer add and sub through negate

* Added tests

Adding rules to merge mult with preceding divide

* Has a couple tests, but needs more
* Added more comments

Fixed bug in integer division folding

* Will no longer merge through integer division if there would be a
remainder in the division
* Added a bunch more tests

Adding rules to merge divide and multiply through divide

* Improved comments
* Added tests

Adding rules to handle mul or div of a negation

* Added tests

Changes for review

* Early exit if no constants are involved in more functions
* fixed some comments
* removed unused declaration
* clarified some logic

Adding new rules for add and subtract

* Fold adds of adds, subtracts or negates
* Fold subtracts of adds, subtracts or negates
* Added tests

source/opt/const_folding_rules.cpp
source/opt/constants.cpp
source/opt/constants.h
source/opt/fold.cpp
source/opt/folding_rules.cpp
source/util/hex_float.h
test/opt/fold_test.cpp

index 6556264..6e82874 100644 (file)
@@ -20,16 +20,6 @@ namespace opt {
 namespace {
 const uint32_t kExtractCompositeIdInIdx = 0;
 
-// Returns a vector that contains the two 32-bit integers that result from
-// splitting |a| in two.  The first entry in vector are the low order bit if
-// |a|.
-inline std::vector<uint32_t> ExtractInts(uint64_t a) {
-  std::vector<uint32_t> result;
-  result.push_back(static_cast<uint32_t>(a));
-  result.push_back(static_cast<uint32_t>(a >> 32));
-  return result;
-}
-
 // Folds an OpcompositeExtract where input is a composite constant.
 ConstantFoldingRule FoldExtractWithConstants() {
   return [](ir::Instruction* inst,
@@ -168,34 +158,6 @@ ConstantFoldingRule FoldFloatingPointOp(FloatScalarFoldingRule scalar_rule) {
   };
 }
 
-// Returns the floating point value of |c|.  The constant |c| must have type
-// |Float|, and width |32|.
-float GetFloatFromConst(const analysis::Constant* c) {
-  assert(c->type()->AsFloat() != nullptr &&
-         c->type()->AsFloat()->width() == 32);
-  const analysis::FloatConstant* fc = c->AsFloatConstant();
-  if (fc) {
-    return fc->GetFloatValue();
-  } else {
-    assert(c->AsNullConstant() && "c must be a float point constant.");
-    return 0.0f;
-  }
-}
-
-// Returns the double value of |c|.  The constant |c| must have type
-// |Float|, and width |64|.
-double GetDoubleFromConst(const analysis::Constant* c) {
-  assert(c->type()->AsFloat() != nullptr &&
-         c->type()->AsFloat()->width() == 64);
-  const analysis::FloatConstant* fc = c->AsFloatConstant();
-  if (fc) {
-    return fc->GetDoubleValue();
-  } else {
-    assert(c->AsNullConstant() && "c must be a float point constant.");
-    return 0.0;
-  }
-}
-
 // This macro defines a |FloatScalarFoldingRule| that applies |op|.  The
 // operator |op| must work for both float and double, and use syntax "f1 op f2".
 #define FOLD_FPARITH_OP(op)                                               \
@@ -207,16 +169,16 @@ double GetDoubleFromConst(const analysis::Constant* c) {
     const analysis::Float* float_type = result_type->AsFloat();           \
     assert(float_type != nullptr);                                        \
     if (float_type->width() == 32) {                                      \
-      float fa = GetFloatFromConst(a);                                    \
-      float fb = GetFloatFromConst(b);                                    \
+      float fa = a->GetFloat();                                           \
+      float fb = b->GetFloat();                                           \
       spvutils::FloatProxy<float> result(fa op fb);                       \
-      std::vector<uint32_t> words = {result.data()};                      \
+      std::vector<uint32_t> words = result.GetWords();                    \
       return const_mgr->GetConstant(result_type, words);                  \
     } else if (float_type->width() == 64) {                               \
-      double fa = GetDoubleFromConst(a);                                  \
-      double fb = GetDoubleFromConst(b);                                  \
+      double fa = a->GetDouble();                                         \
+      double fb = b->GetDouble();                                         \
       spvutils::FloatProxy<double> result(fa op fb);                      \
-      std::vector<uint32_t> words(ExtractInts(result.data()));            \
+      std::vector<uint32_t> words = result.GetWords();                    \
       return const_mgr->GetConstant(result_type, words);                  \
     }                                                                     \
     return nullptr;                                                       \
@@ -260,15 +222,15 @@ bool CompareFloatingPoint(bool op_result, bool op_unordered,
     const analysis::Float* float_type = a->type()->AsFloat();             \
     assert(float_type != nullptr);                                        \
     if (float_type->width() == 32) {                                      \
-      float fa = GetFloatFromConst(a);                                    \
-      float fb = GetFloatFromConst(b);                                    \
+      float fa = a->GetFloat();                                           \
+      float fb = b->GetFloat();                                           \
       bool result = CompareFloatingPoint(                                 \
           fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
       std::vector<uint32_t> words = {uint32_t(result)};                   \
       return const_mgr->GetConstant(result_type, words);                  \
     } else if (float_type->width() == 64) {                               \
-      double fa = GetDoubleFromConst(a);                                  \
-      double fb = GetDoubleFromConst(b);                                  \
+      double fa = a->GetDouble();                                         \
+      double fb = b->GetDouble();                                         \
       bool result = CompareFloatingPoint(                                 \
           fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
       std::vector<uint32_t> words = {uint32_t(result)};                   \
index d5b873f..1eb9efe 100644 (file)
@@ -22,6 +22,76 @@ namespace spvtools {
 namespace opt {
 namespace analysis {
 
+float Constant::GetFloat() const {
+  assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32);
+
+  if (const FloatConstant* fc = AsFloatConstant()) {
+    return fc->GetFloatValue();
+  } else {
+    assert(AsNullConstant() && "Must be a floating point constant.");
+    return 0.0f;
+  }
+}
+
+double Constant::GetDouble() const {
+  assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64);
+
+  if (const FloatConstant* fc = AsFloatConstant()) {
+    return fc->GetDoubleValue();
+  } else {
+    assert(AsNullConstant() && "Must be a floating point constant.");
+    return 0.0;
+  }
+}
+
+uint32_t Constant::GetU32() const {
+  assert(type()->AsInteger() != nullptr);
+  assert(type()->AsInteger()->width() == 32);
+
+  if (const IntConstant* ic = AsIntConstant()) {
+    return ic->GetU32BitValue();
+  } else {
+    assert(AsNullConstant() && "Must be an integer constant.");
+    return 0u;
+  }
+}
+
+uint64_t Constant::GetU64() const {
+  assert(type()->AsInteger() != nullptr);
+  assert(type()->AsInteger()->width() == 64);
+
+  if (const IntConstant* ic = AsIntConstant()) {
+    return ic->GetU64BitValue();
+  } else {
+    assert(AsNullConstant() && "Must be an integer constant.");
+    return 0u;
+  }
+}
+
+int32_t Constant::GetS32() const {
+  assert(type()->AsInteger() != nullptr);
+  assert(type()->AsInteger()->width() == 32);
+
+  if (const IntConstant* ic = AsIntConstant()) {
+    return ic->GetS32BitValue();
+  } else {
+    assert(AsNullConstant() && "Must be an integer constant.");
+    return 0;
+  }
+}
+
+int64_t Constant::GetS64() const {
+  assert(type()->AsInteger() != nullptr);
+  assert(type()->AsInteger()->width() == 64);
+
+  if (const IntConstant* ic = AsIntConstant()) {
+    return ic->GetS64BitValue();
+  } else {
+    assert(AsNullConstant() && "Must be an integer constant.");
+    return 0;
+  }
+}
+
 ConstantManager::ConstantManager(ir::IRContext* ctx) : ctx_(ctx) {
   // Populate the constant table with values from constant declarations in the
   // module.  The values of each OpConstant declaration is the identity
@@ -35,6 +105,22 @@ Type* ConstantManager::GetType(const ir::Instruction* inst) const {
   return context()->get_type_mgr()->GetType(inst->type_id());
 }
 
+std::vector<const Constant*> ConstantManager::GetOperandConstants(
+    ir::Instruction* inst) const {
+  std::vector<const Constant*> constants;
+  for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
+    const ir::Operand* operand = &inst->GetInOperand(i);
+    if (operand->type != SPV_OPERAND_TYPE_ID) {
+      constants.push_back(nullptr);
+    } else {
+      uint32_t id = operand->words[0];
+      const analysis::Constant* constant = FindDeclaredConstant(id);
+      constants.push_back(constant);
+    }
+  }
+  return constants;
+}
+
 std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
     const std::vector<uint32_t>& ids) const {
   std::vector<const Constant*> constants;
index 6ac9bef..cd3134b 100644 (file)
@@ -83,6 +83,30 @@ class Constant {
   virtual const ArrayConstant* AsArrayConstant() const { return nullptr; }
   virtual const NullConstant* AsNullConstant() const { return nullptr; }
 
+  // Returns the float representation of the constant. Must be a 32 bit
+  // Float type.
+  float GetFloat() const;
+
+  // Returns the double representation of the constant. Must be a 64 bit
+  // Float type.
+  double GetDouble() const;
+
+  // Returns uint32_t representation of the constant. Must be a 32 bit
+  // Integer type.
+  uint32_t GetU32() const;
+
+  // Returns uint64_t representation of the constant. Must be a 64 bit
+  // Integer type.
+  uint64_t GetU64() const;
+
+  // Returns int32_t representation of the constant. Must be a 32 bit
+  // Integer type.
+  int32_t GetS32() const;
+
+  // Returns int64_t representation of the constant. Must be a 64 bit
+  // Integer type.
+  int64_t GetS64() const;
+
   const Type* type() const { return type_; }
 
  protected:
@@ -135,6 +159,22 @@ class IntConstant : public ScalarConstant {
     return words()[0];
   }
 
+  int64_t GetS64BitValue() const {
+    // Relies on unsigned values smaller than 64-bit being sign extended.  See
+    // section 2.2.1 of the SPIR-V spec.
+    assert(words().size() == 2);
+    return static_cast<uint64_t>(words()[1]) << 32 |
+           static_cast<uint64_t>(words()[0]);
+  }
+
+  uint64_t GetU64BitValue() const {
+    // Relies on unsigned values smaller than 64-bit being zero extended.  See
+    // section 2.2.1 of the SPIR-V spec.
+    assert(words().size() == 2);
+    return static_cast<uint64_t>(words()[1]) << 32 |
+           static_cast<uint64_t>(words()[0]);
+  }
+
   bool IsZero() const {
     bool is_zero = true;
     for (uint32_t v : words()) {
@@ -507,6 +547,10 @@ class ConstantManager {
   std::vector<const Constant*> GetConstantsFromIds(
       const std::vector<uint32_t>& ids) const;
 
+  // Returns a vector of constants representing each in operand. If an operand
+  // is not constant its entry is nullptr.
+  std::vector<const Constant*> GetOperandConstants(ir::Instruction* inst) const;
+
   // Records a mapping between |inst| and the constant value generated by it.
   // It returns true if a new Constant was successfully mapped, false if |inst|
   // generates no constant values.
index 6cc486a..678c456 100644 (file)
@@ -194,20 +194,10 @@ bool FoldInstructionInternal(ir::Instruction* inst) {
   }
 
   SpvOp opcode = inst->opcode();
-  analysis::ConstantManager* const_manger = context->get_constant_mgr();
+  analysis::ConstantManager* const_manager = context->get_constant_mgr();
 
-  std::vector<const analysis::Constant*> constants;
-  for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
-    const ir::Operand* operand = &inst->GetInOperand(i);
-    if (operand->type != SPV_OPERAND_TYPE_ID) {
-      constants.push_back(nullptr);
-    } else {
-      uint32_t id = operand->words[0];
-      const analysis::Constant* constant =
-          const_manger->FindDeclaredConstant(id);
-      constants.push_back(constant);
-    }
-  }
+  std::vector<const analysis::Constant*> constants =
+      const_manager->GetOperandConstants(inst);
 
   static FoldingRules* rules = new FoldingRules();
   for (FoldingRule rule : rules->GetRulesForOpcode(opcode)) {
@@ -466,7 +456,7 @@ bool FoldBinaryBooleanOpToConstant(ir::Instruction* inst,
   }
 
   switch (opcode) {
-      // Logical
+    // Logical
     case SpvOp::SpvOpLogicalOr:
       for (uint32_t i = 0; i < 2; i++) {
         if (constants[i] != nullptr) {
index b0f99b7..4f2c128 100644 (file)
@@ -27,6 +27,1189 @@ const uint32_t kExtInstInstructionInIdx = 1;
 const uint32_t kFMixXIdInIdx = 2;
 const uint32_t kFMixYIdInIdx = 3;
 
+// Returns the element width of |type|.
+uint32_t ElementWidth(const analysis::Type* type) {
+  if (const analysis::Vector* vec_type = type->AsVector()) {
+    return ElementWidth(vec_type->element_type());
+  } else if (const analysis::Float* float_type = type->AsFloat()) {
+    return float_type->width();
+  } else {
+    assert(type->AsInteger());
+    return type->AsInteger()->width();
+  }
+}
+
+// Returns true if |type| is Float or a vector of Float.
+bool HasFloatingPoint(const analysis::Type* type) {
+  if (type->AsFloat()) {
+    return true;
+  } else if (const analysis::Vector* vec_type = type->AsVector()) {
+    return vec_type->element_type()->AsFloat() != nullptr;
+  }
+
+  return false;
+}
+
+// Returns false if |val| is NaN, infinite or subnormal.
+template <typename T>
+bool IsValidResult(T val) {
+  int classified = std::fpclassify(val);
+  switch (classified) {
+    case FP_NAN:
+    case FP_INFINITE:
+    case FP_SUBNORMAL:
+      return false;
+    default:
+      return true;
+  }
+}
+
+const analysis::Constant* ConstInput(
+    const std::vector<const analysis::Constant*>& constants) {
+  return constants[0] ? constants[0] : constants[1];
+}
+
+ir::Instruction* NonConstInput(ir::IRContext* context,
+                               const analysis::Constant* c,
+                               ir::Instruction* inst) {
+  uint32_t in_op = c ? 1u : 0u;
+  return context->get_def_use_mgr()->GetDef(
+      inst->GetSingleWordInOperand(in_op));
+}
+
+// Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
+// constant.
+uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
+                                     const analysis::Constant* c) {
+  assert(c);
+  assert(c->type()->AsFloat());
+  uint32_t width = c->type()->AsFloat()->width();
+  assert(width == 32 || width == 64);
+  std::vector<uint32_t> words;
+  if (width == 64) {
+    spvutils::FloatProxy<double> result(c->GetDouble() * -1.0);
+    words = result.GetWords();
+  } else {
+    spvutils::FloatProxy<float> result(c->GetFloat() * -1.0f);
+    words = result.GetWords();
+  }
+
+  const analysis::Constant* negated_const =
+      const_mgr->GetConstant(c->type(), std::move(words));
+  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
+}
+
+std::vector<uint32_t> ExtractInts(uint64_t val) {
+  std::vector<uint32_t> words;
+  words.push_back(static_cast<uint32_t>(val));
+  words.push_back(static_cast<uint32_t>(val >> 32));
+  return words;
+}
+
+// Negates the integer constant |c|. Returns the id of the defining instruction.
+uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
+                               const analysis::Constant* c) {
+  assert(c);
+  assert(c->type()->AsInteger());
+  uint32_t width = c->type()->AsInteger()->width();
+  assert(width == 32 || width == 64);
+  std::vector<uint32_t> words;
+  if (width == 64) {
+    uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
+    words = ExtractInts(uval);
+  } else {
+    words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
+  }
+
+  const analysis::Constant* negated_const =
+      const_mgr->GetConstant(c->type(), std::move(words));
+  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
+}
+
+// Negates the vector constant |c|. Returns the id of the defining instruction.
+uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
+                              const analysis::Constant* c) {
+  assert(const_mgr && c);
+  assert(c->type()->AsVector());
+  if (c->AsNullConstant()) {
+    // 0.0 vs -0.0 shouldn't matter.
+    return const_mgr->GetDefiningInstruction(c)->result_id();
+  } else {
+    const analysis::Type* component_type =
+        c->AsVectorConstant()->component_type();
+    std::vector<uint32_t> words;
+    for (auto& comp : c->AsVectorConstant()->GetComponents()) {
+      if (component_type->AsFloat()) {
+        words.push_back(NegateFloatingPointConstant(const_mgr, comp));
+      } else {
+        assert(component_type->AsInteger());
+        words.push_back(NegateIntegerConstant(const_mgr, comp));
+      }
+    }
+
+    const analysis::Constant* negated_const =
+        const_mgr->GetConstant(c->type(), std::move(words));
+    return const_mgr->GetDefiningInstruction(negated_const)->result_id();
+  }
+}
+
+// Negates |c|. Returns the id of the defining instruction.
+uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
+                        const analysis::Constant* c) {
+  if (c->type()->AsVector()) {
+    return NegateVectorConstant(const_mgr, c);
+  } else if (c->type()->AsFloat()) {
+    return NegateFloatingPointConstant(const_mgr, c);
+  } else {
+    assert(c->type()->AsInteger());
+    return NegateIntegerConstant(const_mgr, c);
+  }
+}
+
+// Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
+// Returns 0 if the reciprocal is NaN, infinite or subnormal.
+uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
+                    const analysis::Constant* c) {
+  assert(const_mgr && c);
+  assert(c->type()->AsFloat());
+
+  uint32_t width = c->type()->AsFloat()->width();
+  assert(width == 32 || width == 64);
+  std::vector<uint32_t> words;
+  if (width == 64) {
+    spvutils::FloatProxy<double> result(1.0 / c->GetDouble());
+    if (!IsValidResult(result.getAsFloat())) return 0;
+    words = result.GetWords();
+  } else {
+    spvutils::FloatProxy<float> result(1.0f / c->GetFloat());
+    if (!IsValidResult(result.getAsFloat())) return 0;
+    words = result.GetWords();
+  }
+
+  const analysis::Constant* negated_const =
+      const_mgr->GetConstant(c->type(), std::move(words));
+  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
+}
+
+// Replaces fdiv where second operand is constant with fmul.
+FoldingRule ReciprocalFDiv() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFDiv);
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    if (!inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    if (constants[1] != nullptr) {
+      uint32_t id = 0;
+      if (const analysis::VectorConstant* vector_const =
+              constants[1]->AsVectorConstant()) {
+        std::vector<uint32_t> neg_ids;
+        for (auto& comp : vector_const->GetComponents()) {
+          id = Reciprocal(const_mgr, comp);
+          if (id == 0) return false;
+          neg_ids.push_back(id);
+        }
+        const analysis::Constant* negated_const =
+            const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
+        id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
+      } else {
+        id = Reciprocal(const_mgr, constants[1]);
+        if (id == 0) return false;
+      }
+      inst->SetOpcode(SpvOpFMul);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
+           {SPV_OPERAND_TYPE_ID, {id}}});
+      return true;
+    }
+
+    return false;
+  };
+};
+
+// Elides consecutive negate instructions.
+FoldingRule MergeNegateArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
+    (void)constants;
+    ir::IRContext* context = inst->context();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    ir::Instruction* op_inst =
+        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
+    if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (op_inst->opcode() == inst->opcode()) {
+      // Elide negates.
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Merges negate into a mul or div operation if that operation contains a
+// constant operand.
+// Cases:
+// -(x * 2) = x * -2
+// -(2 * x) = x * -2
+// -(x / 2) = x / -2
+// -(2 / x) = -2 / x
+FoldingRule MergeNegateMulDivArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
+    (void)constants;
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    ir::Instruction* op_inst =
+        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
+    if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    SpvOp opcode = op_inst->opcode();
+    if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
+        opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
+      std::vector<const analysis::Constant*> op_constants =
+          const_mgr->GetOperandConstants(op_inst);
+      // Merge negate into mul or div if one operand is constant.
+      if (op_constants[0] || op_constants[1]) {
+        bool zero_is_variable = op_constants[0] == nullptr;
+        const analysis::Constant* c = ConstInput(op_constants);
+        uint32_t neg_id = NegateConstant(const_mgr, c);
+        uint32_t non_const_id = zero_is_variable
+                                    ? op_inst->GetSingleWordInOperand(0u)
+                                    : op_inst->GetSingleWordInOperand(1u);
+        // Change this instruction to a mul/div.
+        inst->SetOpcode(op_inst->opcode());
+        if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
+          uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
+          uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
+          inst->SetInOperands(
+              {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
+        } else {
+          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
+                               {SPV_OPERAND_TYPE_ID, {neg_id}}});
+        }
+        return true;
+      }
+    }
+
+    return false;
+  };
+}
+
+// Merges negate into a add or sub operation if that operation contains a
+// constant operand.
+// Cases:
+// -(x + 2) = -2 - x
+// -(2 + x) = -2 - x
+// -(x - 2) = 2 - x
+// -(2 - x) = x - 2
+FoldingRule MergeNegateAddSubArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
+    (void)constants;
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    ir::Instruction* op_inst =
+        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
+    if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
+        op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
+      std::vector<const analysis::Constant*> op_constants =
+          const_mgr->GetOperandConstants(op_inst);
+      if (op_constants[0] || op_constants[1]) {
+        bool zero_is_variable = op_constants[0] == nullptr;
+        bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
+                      (op_inst->opcode() == SpvOpIAdd);
+        bool swap_operands = !is_add || zero_is_variable;
+        bool negate_const = is_add;
+        const analysis::Constant* c = ConstInput(op_constants);
+        uint32_t const_id = 0;
+        if (negate_const) {
+          const_id = NegateConstant(const_mgr, c);
+        } else {
+          const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
+                                      : op_inst->GetSingleWordInOperand(0u);
+        }
+
+        // Swap operands if necessary and make the instruction a subtraction.
+        uint32_t op0 =
+            zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
+        uint32_t op1 =
+            zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
+        if (swap_operands) std::swap(op0, op1);
+        inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
+        inst->SetInOperands(
+            {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
+        return true;
+      }
+    }
+
+    return false;
+  };
+}
+
+// Performs |input1| |opcode| |input2| and returns the merged constant result
+// id. Returns 0 if the result is not a valid value. The input types must be
+// Float.
+uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
+                                       SpvOp opcode,
+                                       const analysis::Constant* input1,
+                                       const analysis::Constant* input2) {
+  const analysis::Type* type = input1->type();
+  assert(type->AsFloat());
+  uint32_t width = type->AsFloat()->width();
+  assert(width == 32 || width == 64);
+  std::vector<uint32_t> words;
+#define FOLD_OP(op)                                 \
+  if (width == 64) {                                \
+    spvutils::FloatProxy<double> val =              \
+        input1->GetDouble() op input2->GetDouble(); \
+    double dval = val.getAsFloat();                 \
+    if (!IsValidResult(dval)) return 0;             \
+    words = val.GetWords();                         \
+  } else {                                          \
+    spvutils::FloatProxy<float> val =               \
+        input1->GetFloat() op input2->GetFloat();   \
+    float fval = val.getAsFloat();                  \
+    if (!IsValidResult(fval)) return 0;             \
+    words = val.GetWords();                         \
+  }
+  switch (opcode) {
+    case SpvOpFMul:
+      FOLD_OP(*);
+      break;
+    case SpvOpFDiv:
+      FOLD_OP(/);
+      break;
+    case SpvOpFAdd:
+      FOLD_OP(+);
+      break;
+    case SpvOpFSub:
+      FOLD_OP(-);
+      break;
+    default:
+      assert(false && "Unexpected operation");
+      break;
+  }
+#undef FOLD_OP
+  const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
+  return const_mgr->GetDefiningInstruction(merged_const)->result_id();
+}
+
+// Performs |input1| |opcode| |input2| and returns the merged constant result
+// id. Returns 0 if the result is not a valid value. The input types must be
+// Integers.
+uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
+                                 SpvOp opcode, const analysis::Constant* input1,
+                                 const analysis::Constant* input2) {
+  assert(input1->type()->AsInteger());
+  const analysis::Integer* type = input1->type()->AsInteger();
+  uint32_t width = type->AsInteger()->width();
+  assert(width == 32 || width == 64);
+  std::vector<uint32_t> words;
+#define FOLD_OP(op)                                        \
+  if (width == 64) {                                       \
+    if (type->IsSigned()) {                                \
+      int64_t val = input1->GetS64() op input2->GetS64();  \
+      words = ExtractInts(static_cast<uint64_t>(val));     \
+    } else {                                               \
+      uint64_t val = input1->GetU64() op input2->GetU64(); \
+      words = ExtractInts(val);                            \
+    }                                                      \
+  } else {                                                 \
+    if (type->IsSigned()) {                                \
+      int32_t val = input1->GetS32() op input2->GetS32();  \
+      words.push_back(static_cast<uint32_t>(val));         \
+    } else {                                               \
+      uint32_t val = input1->GetU32() op input2->GetU32(); \
+      words.push_back(val);                                \
+    }                                                      \
+  }
+  switch (opcode) {
+    case SpvOpIMul:
+      FOLD_OP(*);
+      break;
+    case SpvOpSDiv:
+    case SpvOpUDiv:
+      // To avoid losing precision we won't perform division that would result
+      // in a remainder. Unfortunate code duplication results.
+      if (input2->AsIntConstant()->IsZero()) return 0;
+      if (width == 64) {
+        if (type->IsSigned()) {
+          if (input1->GetS64() % input2->GetS64() != 0) return 0;
+          int64_t val = input1->GetS64() / input2->GetS64();
+          words = ExtractInts(static_cast<uint64_t>(val));
+        } else {
+          if (input1->GetU64() % input2->GetU64() != 0) return 0;
+          uint64_t val = input1->GetU64() / input2->GetU64();
+          words = ExtractInts(val);
+        }
+      } else {
+        if (type->IsSigned()) {
+          if (input1->GetS32() % input2->GetS32() != 0) return 0;
+          int32_t val = input1->GetS32() / input2->GetS32();
+          words.push_back(static_cast<uint32_t>(val));
+        } else {
+          if (input1->GetU32() % input2->GetU32() != 0) return 0;
+          uint32_t val = input1->GetU32() / input2->GetU32();
+          words.push_back(val);
+        }
+      }
+      break;
+    case SpvOpIAdd:
+      FOLD_OP(+);
+      break;
+    case SpvOpISub:
+      FOLD_OP(-);
+      break;
+    default:
+      assert(false && "Unexpected operation");
+      break;
+  }
+#undef FOLD_OP
+  const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
+  return const_mgr->GetDefiningInstruction(merged_const)->result_id();
+}
+
+// Performs |input1| |opcode| |input2| and returns the merged constant result
+// id. Returns 0 if the result is not a valid value. The input types must be
+// Integers, Floats or Vectors of such.
+uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
+                          const analysis::Constant* input1,
+                          const analysis::Constant* input2) {
+  assert(input1 && input2);
+  assert(input1->type() == input2->type());
+  const analysis::Type* type = input1->type();
+  std::vector<uint32_t> words;
+  if (const analysis::Vector* vector_type = type->AsVector()) {
+    const analysis::Type* ele_type = vector_type->element_type();
+    for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
+      uint32_t id = 0;
+      const analysis::Constant* input1_comp =
+          input1->AsVectorConstant()->GetComponents()[i];
+      const analysis::Constant* input2_comp =
+          input2->AsVectorConstant()->GetComponents()[i];
+      if (ele_type->AsFloat()) {
+        id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
+                                           input2_comp);
+      } else {
+        assert(ele_type->AsInteger());
+        id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
+                                     input2_comp);
+      }
+      if (id == 0) return 0;
+      words.push_back(id);
+    }
+    const analysis::Constant* merged_const =
+        const_mgr->GetConstant(type, words);
+    return const_mgr->GetDefiningInstruction(merged_const)->result_id();
+  } else if (type->AsFloat()) {
+    return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
+  } else {
+    assert(type->AsInteger());
+    return PerformIntegerOperation(const_mgr, opcode, input1, input2);
+  }
+}
+
+// Merges consecutive multiplies where each contains one constant operand.
+// Cases:
+// 2 * (x * 2) = x * 4
+// 2 * (2 * x) = x * 4
+// (x * 2) * 2 = x * 4
+// (2 * x) * 2 = x * 4
+FoldingRule MergeMulMulArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    // Determine the constant input and the variable input in |inst|.
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == inst->opcode()) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      bool other_first_is_variable = other_constants[0] == nullptr;
+      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+                                            const_input1, const_input2);
+      if (merged_id == 0) return false;
+
+      uint32_t non_const_id = other_first_is_variable
+                                  ? other_inst->GetSingleWordInOperand(0u)
+                                  : other_inst->GetSingleWordInOperand(1u);
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
+                           {SPV_OPERAND_TYPE_ID, {merged_id}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Merges divides into subsequent multiplies if each instruction contains one
+// constant operand.
+// Cases:
+// 2 * (x / 2) = 4 / x
+// 2 * (2 / x) = x * 1
+// (x / 2) * 2 = x * 1
+// (2 / x) * 2 = 4 / x
+FoldingRule MergeMulDivArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpFDiv ||
+        other_inst->opcode() == SpvOpSDiv ||
+        other_inst->opcode() == SpvOpUDiv) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      bool other_first_is_variable = other_constants[0] == nullptr;
+      // If the variable value is the second operand of the divide, multiply
+      // the constants together. Otherwise divide the constants.
+      uint32_t merged_id = PerformOperation(
+          const_mgr,
+          other_first_is_variable ? other_inst->opcode() : inst->opcode(),
+          const_input1, const_input2);
+      if (merged_id == 0) return false;
+
+      uint32_t non_const_id = other_first_is_variable
+                                  ? other_inst->GetSingleWordInOperand(0u)
+                                  : other_inst->GetSingleWordInOperand(1u);
+
+      // If the variable value is on the second operand of the div, then this
+      // operation is a div. Otherwise it should be a multiply.
+      inst->SetOpcode(other_first_is_variable ? inst->opcode()
+                                              : other_inst->opcode());
+      if (other_first_is_variable) {
+        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
+                             {SPV_OPERAND_TYPE_ID, {merged_id}}});
+      } else {
+        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
+                             {SPV_OPERAND_TYPE_ID, {non_const_id}}});
+      }
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Merges multiply of constant and negation.
+// Cases:
+// (-x) * 2 = x * -2
+// 2 * (-x) = x * -2
+FoldingRule MergeMulNegateArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpFNegate ||
+        other_inst->opcode() == SpvOpSNegate) {
+      uint32_t neg_id = NegateConstant(const_mgr, const_input1);
+
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
+           {SPV_OPERAND_TYPE_ID, {neg_id}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Merges consecutive divides if each instruction contains one constant operand.
+// Cases:
+// 2 / (x / 2) = 4 / x
+// 4 / (2 / x) = 2 * x
+// (4 / x) / 2 = 2 / x
+// (x / 2) / 2 = x / 4
+FoldingRule MergeDivDivArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
+           inst->opcode() == SpvOpUDiv);
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    bool first_is_variable = constants[0] == nullptr;
+    if (other_inst->opcode() == inst->opcode()) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      bool other_first_is_variable = other_constants[0] == nullptr;
+
+      SpvOp merge_op = inst->opcode();
+      if (other_first_is_variable) {
+        // Constants magnify.
+        merge_op = uses_float ? SpvOpFMul : SpvOpIMul;
+      }
+
+      // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
+      // because it is commutative.
+      if (first_is_variable) std::swap(const_input1, const_input2);
+      uint32_t merged_id =
+          PerformOperation(const_mgr, merge_op, const_input1, const_input2);
+      if (merged_id == 0) return false;
+
+      uint32_t non_const_id = other_first_is_variable
+                                  ? other_inst->GetSingleWordInOperand(0u)
+                                  : other_inst->GetSingleWordInOperand(1u);
+
+      SpvOp op = inst->opcode();
+      if (!first_is_variable && !other_first_is_variable) {
+        // Effectively div of 1/x, so change to multiply.
+        op = uses_float ? SpvOpFMul : SpvOpIMul;
+      }
+
+      uint32_t op1 = merged_id;
+      uint32_t op2 = non_const_id;
+      if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
+      inst->SetOpcode(op);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Fold multiplies succeeded by divides where each instruction contains a
+// constant operand.
+// Cases:
+// 4 / (x * 2) = 2 / x
+// 4 / (2 * x) = 2 / x
+// (x * 4) / 2 = x * 2
+// (4 * x) / 2 = x * 2
+FoldingRule MergeDivMulArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
+           inst->opcode() == SpvOpUDiv);
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    bool first_is_variable = constants[0] == nullptr;
+    if (other_inst->opcode() == SpvOpFMul ||
+        other_inst->opcode() == SpvOpIMul) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      bool other_first_is_variable = other_constants[0] == nullptr;
+
+      // This is an x / (*) case. Swap the inputs.
+      if (first_is_variable) std::swap(const_input1, const_input2);
+      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+                                            const_input1, const_input2);
+      if (merged_id == 0) return false;
+
+      uint32_t non_const_id = other_first_is_variable
+                                  ? other_inst->GetSingleWordInOperand(0u)
+                                  : other_inst->GetSingleWordInOperand(1u);
+
+      uint32_t op1 = merged_id;
+      uint32_t op2 = non_const_id;
+      if (first_is_variable) std::swap(op1, op2);
+
+      // Convert to multiply
+      if (first_is_variable) inst->SetOpcode(other_inst->opcode());
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Fold divides of a constant and a negation.
+// Cases:
+// (-x) / 2 = x / -2
+// 2 / (-x) = 2 / -x
+FoldingRule MergeDivNegateArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
+           inst->opcode() == SpvOpUDiv);
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    bool first_is_variable = constants[0] == nullptr;
+    if (other_inst->opcode() == SpvOpFNegate ||
+        other_inst->opcode() == SpvOpSNegate) {
+      uint32_t neg_id = NegateConstant(const_mgr, const_input1);
+
+      if (first_is_variable) {
+        inst->SetInOperands(
+            {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
+             {SPV_OPERAND_TYPE_ID, {neg_id}}});
+      } else {
+        inst->SetInOperands(
+            {{SPV_OPERAND_TYPE_ID, {neg_id}},
+             {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
+      }
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Folds addition of a constant and a negation.
+// Cases:
+// (-x) + 2 = 2 - x
+// 2 + (-x) = 2 - x
+FoldingRule MergeAddNegateArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
+    ir::IRContext* context = inst->context();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpSNegate ||
+        other_inst->opcode() == SpvOpFNegate) {
+      inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
+      uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
+                                       : inst->GetSingleWordInOperand(1u);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {const_id}},
+           {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
+      return true;
+    }
+    return false;
+  };
+}
+
+// Folds subtraction of a constant and a negation.
+// Cases:
+// (-x) - 2 = -2 - x
+// 2 - (-x) = x + 2
+FoldingRule MergeSubNegateArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpSNegate ||
+        other_inst->opcode() == SpvOpFNegate) {
+      uint32_t op1 = 0;
+      uint32_t op2 = 0;
+      SpvOp opcode = inst->opcode();
+      if (constants[0] != nullptr) {
+        op1 = other_inst->GetSingleWordInOperand(0u);
+        op2 = inst->GetSingleWordInOperand(0u);
+        opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
+      } else {
+        op1 = NegateConstant(const_mgr, const_input1);
+        op2 = other_inst->GetSingleWordInOperand(0u);
+      }
+
+      inst->SetOpcode(opcode);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+    return false;
+  };
+}
+
+// Folds addition of an addition where each operation has a constant operand.
+// Cases:
+// (x + 2) + 2 = x + 4
+// (2 + x) + 2 = x + 4
+// 2 + (x + 2) = x + 4
+// 2 + (2 + x) = x + 4
+FoldingRule MergeAddAddArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
+    ir::IRContext* context = inst->context();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpFAdd ||
+        other_inst->opcode() == SpvOpIAdd) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      ir::Instruction* non_const_input =
+          NonConstInput(context, other_constants[0], other_inst);
+      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+                                            const_input1, const_input2);
+      if (merged_id == 0) return false;
+
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
+           {SPV_OPERAND_TYPE_ID, {merged_id}}});
+      return true;
+    }
+    return false;
+  };
+}
+
+// Folds addition of a subtraction where each operation has a constant operand.
+// Cases:
+// (x - 2) + 2 = x + 0
+// (2 - x) + 2 = 4 - x
+// 2 + (x - 2) = x + 0
+// 2 + (2 - x) = 4 - x
+FoldingRule MergeAddSubArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
+    ir::IRContext* context = inst->context();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpFSub ||
+        other_inst->opcode() == SpvOpISub) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      bool first_is_variable = other_constants[0] == nullptr;
+      SpvOp op = inst->opcode();
+      uint32_t op1 = 0;
+      uint32_t op2 = 0;
+      if (first_is_variable) {
+        // Subtract constants. Non-constant operand is first.
+        op1 = other_inst->GetSingleWordInOperand(0u);
+        op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
+                               const_input2);
+      } else {
+        // Add constants. Constant operand is first. Change the opcode.
+        op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
+                               const_input2);
+        op2 = other_inst->GetSingleWordInOperand(1u);
+        op = other_inst->opcode();
+      }
+      if (op1 == 0 || op2 == 0) return false;
+
+      inst->SetOpcode(op);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+    return false;
+  };
+}
+
+// Folds subtraction of an addition where each operand has a constant operand.
+// Cases:
+// (x + 2) - 2 = x + 0
+// (2 + x) - 2 = x + 0
+// 2 - (x + 2) = 0 - x
+// 2 - (2 + x) = 0 - x
+FoldingRule MergeSubAddArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
+    ir::IRContext* context = inst->context();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpFAdd ||
+        other_inst->opcode() == SpvOpIAdd) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      ir::Instruction* non_const_input =
+          NonConstInput(context, other_constants[0], other_inst);
+
+      // If the first operand of the sub is not a constant, swap the constants
+      // so the subtraction has the correct operands.
+      if (constants[0] == nullptr) std::swap(const_input1, const_input2);
+      // Subtract the constants.
+      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+                                            const_input1, const_input2);
+      SpvOp op = inst->opcode();
+      uint32_t op1 = 0;
+      uint32_t op2 = 0;
+      if (constants[0] == nullptr) {
+        // Non-constant operand is first. Change the opcode.
+        op1 = non_const_input->result_id();
+        op2 = merged_id;
+        op = other_inst->opcode();
+      } else {
+        // Constant operand is first.
+        op1 = merged_id;
+        op2 = non_const_input->result_id();
+      }
+      if (op1 == 0 || op2 == 0) return false;
+
+      inst->SetOpcode(op);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+    return false;
+  };
+}
+
+// Folds subtraction of a subtraction where each operand has a constant operand.
+// Cases:
+// (x - 2) - 2 = x - 4
+// (2 - x) - 2 = 0 - x
+// 2 - (x - 2) = 4 - x
+// 2 - (2 - x) = x + 0
+FoldingRule MergeSubSubArithmetic() {
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
+    ir::IRContext* context = inst->context();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpFSub ||
+        other_inst->opcode() == SpvOpISub) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      ir::Instruction* non_const_input =
+          NonConstInput(context, other_constants[0], other_inst);
+
+      // Merge the constants.
+      uint32_t merged_id = 0;
+      SpvOp merge_op = inst->opcode();
+      if (other_constants[0] == nullptr) {
+        merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
+      } else if (constants[0] == nullptr) {
+        std::swap(const_input1, const_input2);
+      }
+      merged_id =
+          PerformOperation(const_mgr, merge_op, const_input1, const_input2);
+      if (merged_id == 0) return false;
+
+      SpvOp op = inst->opcode();
+      if (constants[0] != nullptr && other_constants[0] != nullptr) {
+        // Change the operation.
+        op = uses_float ? SpvOpFAdd : SpvOpIAdd;
+      }
+
+      uint32_t op1 = 0;
+      uint32_t op2 = 0;
+      if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
+        op1 = merged_id;
+        op2 = non_const_input->result_id();
+      } else {
+        op1 = non_const_input->result_id();
+        op2 = merged_id;
+      }
+
+      inst->SetOpcode(op);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+    return false;
+  };
+}
+
 FoldingRule IntMultipleBy1() {
   return [](ir::Instruction* inst,
             const std::vector<const analysis::Constant*>& constants) {
@@ -36,11 +1219,17 @@ FoldingRule IntMultipleBy1() {
         continue;
       }
       const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
-      if (int_constant && int_constant->GetU32BitValue() == 1) {
-        inst->SetOpcode(SpvOpCopyObject);
-        inst->SetInOperands(
-            {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
-        return true;
+      if (int_constant) {
+        uint32_t width = ElementWidth(int_constant->type());
+        if (width != 32 && width != 64) return false;
+        bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
+                                    : int_constant->GetU64BitValue() == 1ull;
+        if (is_one) {
+          inst->SetOpcode(SpvOpCopyObject);
+          inst->SetInOperands(
+              {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
+          return true;
+        }
       }
     }
     return false;
@@ -540,15 +1729,58 @@ spvtools::opt::FoldingRules::FoldingRules() {
   rules_[SpvOpExtInst].push_back(RedundantFMix());
 
   rules_[SpvOpFAdd].push_back(RedundantFAdd());
+  rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
+  rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
+  rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
+
   rules_[SpvOpFDiv].push_back(RedundantFDiv());
+  rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
+  rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
+  rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
+  rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
+
   rules_[SpvOpFMul].push_back(RedundantFMul());
+  rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
+  rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
+  rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
+
+  rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
+  rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
+  rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
+
   rules_[SpvOpFSub].push_back(RedundantFSub());
+  rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
+  rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
+  rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
+
+  rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
+  rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
+  rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
 
   rules_[SpvOpIMul].push_back(IntMultipleBy1());
+  rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
+  rules_[SpvOpIMul].push_back(MergeMulDivArithmetic());
+  rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
+
+  rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
+  rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
+  rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
 
   rules_[SpvOpPhi].push_back(RedundantPhi());
 
+  rules_[SpvOpSDiv].push_back(MergeDivDivArithmetic());
+  rules_[SpvOpSDiv].push_back(MergeDivMulArithmetic());
+  rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
+
+  rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
+  rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
+  rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
+
   rules_[SpvOpSelect].push_back(RedundantSelect());
+
+  rules_[SpvOpUDiv].push_back(MergeDivDivArithmetic());
+  rules_[SpvOpUDiv].push_back(MergeDivMulArithmetic());
+  rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
 }
 
 }  // namespace opt
index 012b3b7..9825a31 100644 (file)
@@ -22,6 +22,7 @@
 #include <iomanip>
 #include <limits>
 #include <sstream>
+#include <vector>
 
 #include "bitutils.h"
 
@@ -82,6 +83,8 @@ struct FloatProxyTraits<float> {
   static uint_type getBitsFromFloat(const float& t) {
     return BitwiseCast<uint_type>(t);
   }
+  // Returns the bitwidth.
+  static uint32_t width() { return 32u; }
 };
 
 template <>
@@ -102,6 +105,8 @@ struct FloatProxyTraits<double> {
   static uint_type getBitsFromFloat(const double& t) {
     return BitwiseCast<uint_type>(t);
   }
+  // Returns the bitwidth.
+  static uint32_t width() { return 64u; }
 };
 
 template <>
@@ -118,6 +123,8 @@ struct FloatProxyTraits<Float16> {
   static Float16 getAsFloat(const uint_type& t) { return Float16(t); }
   // Returns the bits from the given floating pointer number.
   static uint_type getBitsFromFloat(const Float16& t) { return t.get_value(); }
+  // Returns the bitwidth.
+  static uint32_t width() { return 16u; }
 };
 
 // Since copying a floating point number (especially if it is NaN)
@@ -152,6 +159,19 @@ class FloatProxy {
   // Returns the raw data.
   uint_type data() const { return data_; }
 
+  // Returns a vector of words suitable for use in an Operand.
+  std::vector<uint32_t> GetWords() const {
+    std::vector<uint32_t> words;
+    if (FloatProxyTraits<T>::width() == 64) {
+      FloatProxyTraits<double>::uint_type d = data();
+      words.push_back(static_cast<uint32_t>(d));
+      words.push_back(static_cast<uint32_t>(d >> 32));
+    } else {
+      words.push_back(static_cast<uint32_t>(data()));
+    }
+    return words;
+  }
+
   // Returns true if the value represents any type of NaN.
   bool isNan() { return FloatProxyTraits<T>::isNan(getAsFloat()); }
   // Returns true if the value represents any type of infinity.
index 8cb7dcd..0fa1d15 100644 (file)
 #include <gtest/gtest.h>
 #include <opt/fold.h>
 
+#ifdef SPIRV_EFFCEE
+#include "effcee/effcee.h"
+#endif
+
 #include "opt/build_module.h"
 #include "opt/def_use_manager.h"
 #include "opt/ir_context.h"
@@ -32,6 +36,31 @@ using ::testing::Contains;
 using namespace spvtools;
 using spvtools::opt::analysis::DefUseManager;
 
+std::string Disassemble(const std::string& original, ir::IRContext* context,
+                        uint32_t disassemble_options = 0) {
+  std::vector<uint32_t> optimized_bin;
+  context->module()->ToBinary(&optimized_bin, true);
+  spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
+  SpirvTools tools(target_env);
+  std::string optimized_asm;
+  EXPECT_TRUE(
+      tools.Disassemble(optimized_bin, &optimized_asm, disassemble_options))
+      << "Disassembling failed for shader:\n"
+      << original << std::endl;
+  return optimized_asm;
+}
+
+#ifdef SPIRV_EFFCEE
+void Match(const std::string& original, ir::IRContext* context,
+           uint32_t disassemble_options = 0) {
+  std::string disassembly = Disassemble(original, context, disassemble_options);
+  auto match_result = effcee::Match(disassembly, original);
+  EXPECT_EQ(effcee::Result::Status::Ok, match_result.status())
+      << match_result.message() << "\nChecking result:\n"
+      << disassembly;
+}
+#endif
+
 template <class ResultType>
 struct InstructionFoldingCase {
   InstructionFoldingCase(const std::string& tb, uint32_t id, ResultType result)
@@ -104,36 +133,49 @@ OpName %main "main"
 %short = OpTypeInt 16 1
 %int = OpTypeInt 32 1
 %long = OpTypeInt 64 1
-%uint = OpTypeInt 32 1
+%uint = OpTypeInt 32 0
 %v2int = OpTypeVector %int 2
 %v4int = OpTypeVector %int 4
 %v4float = OpTypeVector %float 4
 %v4double = OpTypeVector %double 4
+%v2float = OpTypeVector %float 2
 %struct_v2int_int_int = OpTypeStruct %v2int %int %int
 %_ptr_int = OpTypePointer Function %int
 %_ptr_uint = OpTypePointer Function %uint
 %_ptr_bool = OpTypePointer Function %bool
 %_ptr_float = OpTypePointer Function %float
 %_ptr_double = OpTypePointer Function %double
+%_ptr_long = OpTypePointer Function %long
+%_ptr_v2int = OpTypePointer Function %v2int
 %_ptr_v4float = OpTypePointer Function %v4float
 %_ptr_v4double = OpTypePointer Function %v4double
 %_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int
+%_ptr_v2float = OpTypePointer Function %v2float
 %short_0 = OpConstant %short 0
 %short_3 = OpConstant %short 3
 %100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
 %103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps.
 %int_0 = OpConstant %int 0
 %int_1 = OpConstant %int 1
+%int_2 = OpConstant %int 2
 %int_3 = OpConstant %int 3
+%int_4 = OpConstant %int 4
 %int_min = OpConstant %int -2147483648
 %int_max = OpConstant %int 2147483647
 %long_0 = OpConstant %long 0
+%long_2 = OpConstant %long 2
 %long_3 = OpConstant %long 3
 %uint_0 = OpConstant %uint 0
+%uint_2 = OpConstant %uint 2
 %uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
 %uint_32 = OpConstant %uint 32
-%uint_max = OpConstant %uint -1
+%uint_max = OpConstant %uint 4294967295
 %v2int_undef = OpUndef %v2int
+%v2int_2_2 = OpConstantComposite %v2int %int_2 %int_2
+%v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3
+%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
+%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
 %struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
 %102 = OpConstantComposite %v2int %103 %103
 %v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
@@ -147,6 +189,11 @@ OpName %main "main"
 %float_1 = OpConstant %float 1
 %float_2 = OpConstant %float 2
 %float_3 = OpConstant %float 3
+%float_4 = OpConstant %float 4
+%float_0p5 = OpConstant %float 0.5
+%v2float_2_3 = OpConstantComposite %v2float %float_2 %float_3
+%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
+%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
 %double_n1 = OpConstant %double -1
 %105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
 %double_0 = OpConstant %double 0
@@ -2264,17 +2311,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 0),
-    // Test case 3: Don't fold n / 2.0
-    InstructionFoldingCase<uint32_t>(
-        Header() + "%main = OpFunction %void None %void_func\n" +
-            "%main_lab = OpLabel\n" +
-            "%n = OpVariable %_ptr_float Function\n" +
-            "%3 = OpLoad %float %n\n" +
-            "%2 = OpFDiv %float %3 %float_2\n" +
-            "OpReturn\n" +
-            "OpFunctionEnd",
-        2, 0),
-    // Test case 4: Fold n + 0.0
+    // Test case 3: Fold n + 0.0
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2284,7 +2321,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 5: Fold 0.0 + n
+    // Test case 4: Fold 0.0 + n
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2294,7 +2331,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 6: Fold n - 0.0
+    // Test case 5: Fold n - 0.0
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2304,7 +2341,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 7: Fold n * 1.0
+    // Test case 6: Fold n * 1.0
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2314,7 +2351,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 8: Fold 1.0 * n
+    // Test case 7: Fold 1.0 * n
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2324,7 +2361,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 9: Fold n / 1.0
+    // Test case 8: Fold n / 1.0
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2334,7 +2371,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 10: Fold n * 0.0
+    // Test case 9: Fold n * 0.0
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2344,7 +2381,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, FLOAT_0_ID),
-    // Test case 11: Fold 0.0 * n
+    // Test case 10: Fold 0.0 * n
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2354,7 +2391,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, FLOAT_0_ID),
-    // Test case 12: Fold 0.0 / n
+    // Test case 11: Fold 0.0 / n
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2364,7 +2401,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, FLOAT_0_ID),
-    // Test case 13: Don't fold mix(a, b, 2.0)
+    // Test case 12: Don't fold mix(a, b, 2.0)
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2376,7 +2413,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 0),
-    // Test case 14: Fold mix(a, b, 0.0)
+    // Test case 13: Fold mix(a, b, 0.0)
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2388,7 +2425,7 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 15: Fold mix(a, b, 1.0)
+    // Test case 14: Fold mix(a, b, 1.0)
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2434,17 +2471,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 0),
-    // Test case 3: Don't fold n / 2.0
-    InstructionFoldingCase<uint32_t>(
-        Header() + "%main = OpFunction %void None %void_func\n" +
-            "%main_lab = OpLabel\n" +
-            "%n = OpVariable %_ptr_double Function\n" +
-            "%3 = OpLoad %double %n\n" +
-            "%2 = OpFDiv %double %3 %double_2\n" +
-            "OpReturn\n" +
-            "OpFunctionEnd",
-        2, 0),
-    // Test case 4: Fold n + 0.0
+    // Test case 3: Fold n + 0.0
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2454,7 +2481,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 5: Fold 0.0 + n
+    // Test case 4: Fold 0.0 + n
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2464,7 +2491,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 6: Fold n - 0.0
+    // Test case 5: Fold n - 0.0
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2474,7 +2501,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 7: Fold n * 1.0
+    // Test case 6: Fold n * 1.0
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2484,7 +2511,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 8: Fold 1.0 * n
+    // Test case 7: Fold 1.0 * n
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2494,7 +2521,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 9: Fold n / 1.0
+    // Test case 8: Fold n / 1.0
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2504,7 +2531,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 10: Fold n * 0.0
+    // Test case 9: Fold n * 0.0
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2514,7 +2541,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, DOUBLE_0_ID),
-    // Test case 11: Fold 0.0 * n
+    // Test case 10: Fold 0.0 * n
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2524,7 +2551,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, DOUBLE_0_ID),
-    // Test case 12: Fold 0.0 / n
+    // Test case 11: Fold 0.0 / n
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2534,7 +2561,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, DOUBLE_0_ID),
-    // Test case 13: Don't fold mix(a, b, 2.0)
+    // Test case 12: Don't fold mix(a, b, 2.0)
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2546,7 +2573,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 0),
-    // Test case 14: Fold mix(a, b, 0.0)
+    // Test case 13: Fold mix(a, b, 0.0)
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2558,7 +2585,7 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTes
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 3),
-    // Test case 15: Fold mix(a, b, 1.0)
+    // Test case 14: Fold mix(a, b, 1.0)
     InstructionFoldingCase<uint32_t>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -2762,6 +2789,1381 @@ INSTANTIATE_TEST_CASE_P(DoubleRedundantSubFoldingTest, ToNegateFoldingTest,
             "OpFunctionEnd",
         2, 3)
 ));
-// clang-format on
 
+#ifdef SPIRV_EFFCEE
+using MatchingInstructionFoldingTest =
+    ::testing::TestWithParam<InstructionFoldingCase<bool>>;
+
+TEST_P(MatchingInstructionFoldingTest, Case) {
+  const auto& tc = GetParam();
+
+  // Build module.
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ASSERT_NE(nullptr, context);
+
+  // Fold the instruction to test.
+  opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+  ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+  std::unique_ptr<ir::Instruction> original_inst(inst->Clone(context.get()));
+  bool succeeded = opt::FoldInstruction(inst);
+  EXPECT_EQ(succeeded, tc.expected_result);
+  if (succeeded) {
+    Match(tc.test_body, context.get());
+  }
+}
+
+INSTANTIATE_TEST_CASE_P(MergeNegateTest, MatchingInstructionFoldingTest,
+::testing::Values(
+  // Test case 0: fold consecutive fnegate
+  // -(-x) = x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float:%\\w+]]\n" +
+      "; CHECK: %4 = OpCopyObject [[float]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFNegate %float %2\n" +
+      "%4 = OpFNegate %float %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 1: fold fnegate(fmul with const).
+  // -(x * 2.0) = x * -2.0
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFMul %float %2 %float_2\n" +
+      "%4 = OpFNegate %float %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 2: fold fnegate(fmul with const).
+  // -(2.0 * x) = x * 2.0
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFMul %float %float_2 %2\n" +
+      "%4 = OpFNegate %float %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 3: fold fnegate(fdiv with const).
+  // -(x / 2.0) = x * -0.5
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_n0p5:%\\w+]] = OpConstant [[float]] -0.5\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n0p5]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFDiv %float %2 %float_2\n" +
+      "%4 = OpFNegate %float %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 4: fold fnegate(fdiv with const).
+  // -(2.0 / x) = -2.0 / x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFDiv [[float]] [[float_n2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFDiv %float %float_2 %2\n" +
+      "%4 = OpFNegate %float %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 5: fold fnegate(fadd with const).
+  // -(2.0 + x) = -2.0 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFAdd %float %float_2 %2\n" +
+      "%4 = OpFNegate %float %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 6: fold fnegate(fadd with const).
+  // -(x + 2.0) = -2.0 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFAdd %float %2 %float_2\n" +
+      "%4 = OpFNegate %float %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 7: fold fnegate(fsub with const).
+  // -(2.0 - x) = x - 2.0
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[ld]] [[float_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %float_2 %2\n" +
+      "%4 = OpFNegate %float %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 8: fold fnegate(fsub with const).
+  // -(x - 2.0) = 2.0 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %2 %float_2\n" +
+      "%4 = OpFNegate %float %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 9: fold consecutive snegate
+  // -(-x) = x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int:%\\w+]]\n" +
+      "; CHECK: %4 = OpCopyObject [[int]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpSNegate %int %2\n" +
+      "%4 = OpSNegate %int %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 10: fold consecutive vector negate
+  // -(-x) = x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[v2float:%\\w+]]\n" +
+      "; CHECK: %4 = OpCopyObject [[v2float]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_v2float Function\n" +
+      "%2 = OpLoad %v2float %var\n" +
+      "%3 = OpFNegate %v2float %2\n" +
+      "%4 = OpFNegate %v2float %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 11: fold snegate(iadd with const).
+  // -(2 + x) = -2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: OpConstant [[int]] -2147483648\n" +
+      "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpISub [[int]] [[int_n2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpIAdd %int %int_2 %2\n" +
+      "%4 = OpSNegate %int %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 12: fold snegate(iadd with const).
+  // -(x + 2) = -2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: OpConstant [[int]] -2147483648\n" +
+      "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpISub [[int]] [[int_n2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpIAdd %int %2 %int_2\n" +
+      "%4 = OpSNegate %int %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 13: fold snegate(isub with const).
+  // -(2 - x) = x - 2
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpISub [[int]] [[ld]] [[int_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpISub %int %int_2 %2\n" +
+      "%4 = OpSNegate %int %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 14: fold snegate(isub with const).
+  // -(x - 2) = 2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpISub [[int]] [[int_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpISub %int %2 %int_2\n" +
+      "%4 = OpSNegate %int %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 15: fold snegate(iadd with const).
+  // -(x + 2) = -2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+      "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+      "; CHECK: %4 = OpISub [[long]] [[long_n2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_long Function\n" +
+      "%2 = OpLoad %long %var\n" +
+      "%3 = OpIAdd %long %2 %long_2\n" +
+      "%4 = OpSNegate %long %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 16: fold snegate(isub with const).
+  // -(2 - x) = x - 2
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+      "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+      "; CHECK: %4 = OpISub [[long]] [[ld]] [[long_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_long Function\n" +
+      "%2 = OpLoad %long %var\n" +
+      "%3 = OpISub %long %long_2 %2\n" +
+      "%4 = OpSNegate %long %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true),
+  // Test case 17: fold snegate(isub with const).
+  // -(x - 2) = 2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+      "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+      "; CHECK: %4 = OpISub [[long]] [[long_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_long Function\n" +
+      "%2 = OpLoad %long %var\n" +
+      "%3 = OpISub %long %2 %long_2\n" +
+      "%4 = OpSNegate %long %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd",
+    4, true)
+));
+
+INSTANTIATE_TEST_CASE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,
+::testing::Values(
+  // Test case 0: scalar reicprocal
+  // x / 0.5 = x * 2.0
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %3 = OpFMul [[float]] [[ld]] [[float_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFDiv %float %2 %float_0p5\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    3, true),
+  // Test case 1: Unfoldable
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_0:%\\w+]] = OpConstant [[float]] 0\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %3 = OpFDiv [[float]] [[ld]] [[float_0]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFDiv %float %2 %104\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    3, false),
+  // Test case 2: Vector reciprocal
+  // x / {2.0, 0.5} = x * {0.5, 2.0}
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" +
+      "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+      "; CHECK: [[float_0p5:%\\w+]] = OpConstant [[float]] 0.5\n" +
+      "; CHECK: [[v2float_0p5_2:%\\w+]] = OpConstantComposite [[v2float]] [[float_0p5]] [[float_2]]\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[v2float]]\n" +
+      "; CHECK: %3 = OpFMul [[v2float]] [[ld]] [[v2float_0p5_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_v2float Function\n" +
+      "%2 = OpLoad %v2float %var\n" +
+      "%3 = OpFDiv %v2float %2 %v2float_2_0p5\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    3, true),
+  // Test case 3: double reciprocal
+  // x / 2.0 = x * 0.5
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
+      "; CHECK: [[double_0p5:%\\w+]] = OpConstant [[double]] 0.5\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[double]]\n" +
+      "; CHECK: %3 = OpFMul [[double]] [[ld]] [[double_0p5]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_double Function\n" +
+      "%2 = OpLoad %double %var\n" +
+      "%3 = OpFDiv %double %2 %double_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    3, true)
+));
+
+INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest,
+::testing::Values(
+  // Test case 0: fold consecutive fmuls
+  // (x * 3.0) * 2.0 = x * 6.0
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_6]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFMul %float %2 %float_3\n" +
+      "%4 = OpFMul %float %3 %float_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 1: fold consecutive fmuls
+  // 2.0 * (x * 3.0) = x * 6.0
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_6]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFMul %float %2 %float_3\n" +
+      "%4 = OpFMul %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 2: fold consecutive fmuls
+  // (3.0 * x) * 2.0 = x * 6.0
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_6]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFMul %float %float_3 %2\n" +
+      "%4 = OpFMul %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 3: fold vector fmul
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" +
+      "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" +
+      "; CHECK: [[v2float_6_6:%\\w+]] = OpConstantComposite [[v2float]] [[float_6]] [[float_6]]\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[v2float]]\n" +
+      "; CHECK: %4 = OpFMul [[v2float]] [[ld]] [[v2float_6_6]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_v2float Function\n" +
+      "%2 = OpLoad %v2float %var\n" +
+      "%3 = OpFMul %v2float %2 %v2float_2_3\n" +
+      "%4 = OpFMul %v2float %3 %v2float_3_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 4: fold double fmuls
+  // (x * 3.0) * 2.0 = x * 6.0
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
+      "; CHECK: [[double_6:%\\w+]] = OpConstant [[double]] 6\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[double]]\n" +
+      "; CHECK: %4 = OpFMul [[double]] [[ld]] [[double_6]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_double Function\n" +
+      "%2 = OpLoad %double %var\n" +
+      "%3 = OpFMul %double %2 %double_3\n" +
+      "%4 = OpFMul %double %3 %double_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 5: fold 32 bit imuls
+  // (x * 3) * 2 = x * 6
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_6:%\\w+]] = OpConstant [[int]] 6\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_6]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpIMul %int %2 %int_3\n" +
+      "%4 = OpIMul %int %3 %int_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 6: fold 64 bit imuls
+  // (x * 3) * 2 = x * 6
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" +
+      "; CHECK: [[long_6:%\\w+]] = OpConstant [[long]] 6\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+      "; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_6]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_long Function\n" +
+      "%2 = OpLoad %long %var\n" +
+      "%3 = OpIMul %long %2 %long_3\n" +
+      "%4 = OpIMul %long %3 %long_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 7: merge vector integer mults
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+      "; CHECK: [[int_6:%\\w+]] = OpConstant [[int]] 6\n" +
+      "; CHECK: [[v2int_6_6:%\\w+]] = OpConstantComposite [[v2int]] [[int_6]] [[int_6]]\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" +
+      "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_6_6]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_v2int Function\n" +
+      "%2 = OpLoad %v2int %var\n" +
+      "%3 = OpIMul %v2int %2 %v2int_2_3\n" +
+      "%4 = OpIMul %v2int %3 %v2int_3_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 8: merge fmul of fdiv
+  // 2.0 * (2.0 / x) = 4.0 / x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_4:%\\w+]] = OpConstant [[float]] 4\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFDiv [[float]] [[float_4]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFDiv %float %float_2 %2\n" +
+      "%4 = OpFMul %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 9: merge fmul of fdiv
+  // (2.0 / x) * 2.0 = 4.0 / x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_4:%\\w+]] = OpConstant [[float]] 4\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFDiv [[float]] [[float_4]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFDiv %float %float_2 %2\n" +
+      "%4 = OpFMul %float %3 %float_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 10: merge imul of sdiv
+  // 4 * (x / 2) = 2 * x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpSDiv %int %2 %int_2\n" +
+      "%4 = OpIMul %int %int_4 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 11: merge imul of sdiv
+  // (x / 2) * 4 = 2 * x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpSDiv %int %2 %int_2\n" +
+      "%4 = OpIMul %int %3 %int_4\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 12: merge imul of udiv
+  // 4 * (x / 2) = 2 * x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
+      "; CHECK: [[uint_2:%\\w+]] = OpConstant [[uint]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[uint]]\n" +
+      "; CHECK: %4 = OpIMul [[uint]] [[ld]] [[uint_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_uint Function\n" +
+      "%2 = OpLoad %uint %var\n" +
+      "%3 = OpUDiv %uint %2 %uint_2\n" +
+      "%4 = OpIMul %uint %uint_4 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 13: merge imul of udiv
+  // (x / 2) * 4 = 2 * x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
+      "; CHECK: [[uint_2:%\\w+]] = OpConstant [[uint]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[uint]]\n" +
+      "; CHECK: %4 = OpIMul [[uint]] [[ld]] [[uint_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_uint Function\n" +
+      "%2 = OpLoad %uint %var\n" +
+      "%3 = OpUDiv %uint %2 %uint_2\n" +
+      "%4 = OpIMul %uint %3 %uint_4\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 14: Don't fold if would have remainder
+  // (x / 3) * 4 
+  InstructionFoldingCase<bool>(
+    Header() +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_uint Function\n" +
+      "%2 = OpLoad %uint %var\n" +
+      "%3 = OpUDiv %uint %2 %uint_3\n" +
+      "%4 = OpIMul %uint %3 %uint_4\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, false),
+  // Test case 15: merge vector imul of sdiv
+  // (x / {2,2}) * {4,4} = x * {2,2}
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[v2int_2_2:%\\w+]] = OpConstantComposite [[v2int]] [[int_2]] [[int_2]]\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" +
+      "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_2_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_v2int Function\n" +
+      "%2 = OpLoad %v2int %var\n" +
+      "%3 = OpSDiv %v2int %2 %v2int_2_2\n" +
+      "%4 = OpIMul %v2int %3 %v2int_4_4\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 15: merge vector imul of snegate
+  // (-x) * {2,2} = x * {-2,-2}
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+      "; CHECK: OpConstant [[int]] -2147483648\n" +
+      "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+      "; CHECK: [[v2int_n2_n2:%\\w+]] = OpConstantComposite [[v2int]] [[int_n2]] [[int_n2]]\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" +
+      "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_n2_n2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_v2int Function\n" +
+      "%2 = OpLoad %v2int %var\n" +
+      "%3 = OpSNegate %v2int %2\n" +
+      "%4 = OpIMul %v2int %3 %v2int_2_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 15: merge vector imul of snegate
+  // {2,2} * (-x) = x * {-2,-2}
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+      "; CHECK: OpConstant [[int]] -2147483648\n" +
+      "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+      "; CHECK: [[v2int_n2_n2:%\\w+]] = OpConstantComposite [[v2int]] [[int_n2]] [[int_n2]]\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" +
+      "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_n2_n2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_v2int Function\n" +
+      "%2 = OpLoad %v2int %var\n" +
+      "%3 = OpSNegate %v2int %2\n" +
+      "%4 = OpIMul %v2int %v2int_2_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true)
+));
+
+INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest,
+::testing::Values(
+  // Test case 0: merge consecutive fdiv
+  // 4.0 / (2.0 / x) = 2.0 * x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFMul [[float]] [[float_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFDiv %float %float_2 %2\n" +
+      "%4 = OpFDiv %float %float_4 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 1: merge consecutive fdiv
+  // 4.0 / (x / 2.0) = 8.0 / x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_8:%\\w+]] = OpConstant [[float]] 8\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFDiv [[float]] [[float_8]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFDiv %float %2 %float_2\n" +
+      "%4 = OpFDiv %float %float_4 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 2: merge consecutive fdiv
+  // (4.0 / x) / 2.0 = 2.0 / x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFDiv [[float]] [[float_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFDiv %float %float_4 %2\n" +
+      "%4 = OpFDiv %float %3 %float_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 3: merge consecutive sdiv
+  // 4 / (2 / x) = 2 * x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpIMul [[int]] [[int_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpSDiv %int %int_2 %2\n" +
+      "%4 = OpSDiv %int %int_4 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 4: merge consecutive sdiv
+  // 4 / (x / 2) = 8 / x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_8:%\\w+]] = OpConstant [[int]] 8\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpSDiv [[int]] [[int_8]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpSDiv %int %2 %int_2\n" +
+      "%4 = OpSDiv %int %int_4 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 5: merge consecutive sdiv
+  // (4 / x) / 2 = 2 / x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpSDiv [[int]] [[int_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpSDiv %int %int_4 %2\n" +
+      "%4 = OpSDiv %int %3 %int_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 6: merge consecutive sdiv
+  // (x / 4) / 2 = x / 8
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_8:%\\w+]] = OpConstant [[int]] 8\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpSDiv [[int]] [[ld]] [[int_8]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpSDiv %int %2 %int_4\n" +
+      "%4 = OpSDiv %int %3 %int_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 7: merge sdiv of imul
+  // 4 / (2 * x) = 2 / x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpSDiv [[int]] [[int_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpIMul %int %int_2 %2\n" +
+      "%4 = OpSDiv %int %int_4 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 8: merge sdiv of imul
+  // 4 / (x * 2) = 2 / x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpSDiv [[int]] [[int_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpIMul %int %2 %int_2\n" +
+      "%4 = OpSDiv %int %int_4 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 9: merge sdiv of imul
+  // (4 * x) / 2 = x * 2
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpIMul %int %int_4 %2\n" +
+      "%4 = OpSDiv %int %3 %int_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 10: merge sdiv of imul
+  // (x * 4) / 2 = x * 2
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpIMul %int %2 %int_4\n" +
+      "%4 = OpSDiv %int %3 %int_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 11: merge sdiv of snegate
+  // (-x) / 2 = x / -2
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: OpConstant [[int]] -2147483648\n" +
+      "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpSDiv [[int]] [[ld]] [[int_n2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpSNegate %int %2\n" +
+      "%4 = OpSDiv %int %3 %int_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 12: merge sdiv of snegate
+  // 2 / (-x) = -2 / x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+      "; CHECK: OpConstant [[int]] -2147483648\n" +
+      "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+      "; CHECK: %4 = OpSDiv [[int]] [[int_n2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_int Function\n" +
+      "%2 = OpLoad %int %var\n" +
+      "%3 = OpSNegate %int %2\n" +
+      "%4 = OpSDiv %int %int_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true)
+));
+
+INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest,
+::testing::Values(
+  // Test case 0: merge add of negate
+  // (-x) + 2 = 2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFNegate %float %2\n" +
+      "%4 = OpFAdd %float %3 %float_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 1: merge add of negate
+  // 2 + (-x) = 2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpSNegate %float %2\n" +
+      "%4 = OpIAdd %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 2: merge add of negate
+  // (-x) + 2 = 2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+      "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+      "; CHECK: %4 = OpISub [[long]] [[long_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_long Function\n" +
+      "%2 = OpLoad %long %var\n" +
+      "%3 = OpSNegate %long %2\n" +
+      "%4 = OpIAdd %long %3 %long_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 3: merge add of negate
+  // 2 + (-x) = 2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+      "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+      "; CHECK: %4 = OpISub [[long]] [[long_2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_long Function\n" +
+      "%2 = OpLoad %long %var\n" +
+      "%3 = OpSNegate %long %2\n" +
+      "%4 = OpIAdd %long %long_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 4: merge add of subtract
+  // (x - 1) + 2 = x + 1
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %2 %float_1\n" +
+      "%4 = OpFAdd %float %3 %float_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 5: merge add of subtract
+  // (1 - x) + 2 = 3 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_3]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %float_1 %2\n" +
+      "%4 = OpFAdd %float %3 %float_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 6: merge add of subtract
+  // 2 + (x - 1) = x + 1
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %2 %float_1\n" +
+      "%4 = OpFAdd %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 7: merge add of subtract
+  // 2 + (1 - x) = 3 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_3]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %float_1 %2\n" +
+      "%4 = OpFAdd %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 8: merge add of add
+  // (x + 1) + 2 = x + 3
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFAdd %float %2 %float_1\n" +
+      "%4 = OpFAdd %float %3 %float_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 9: merge add of add
+  // (1 + x) + 2 = 3 + x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFAdd %float %float_1 %2\n" +
+      "%4 = OpFAdd %float %3 %float_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 10: merge add of add
+  // 2 + (x + 1) = x + 1
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFAdd %float %2 %float_1\n" +
+      "%4 = OpFAdd %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 11: merge add of add
+  // 2 + (1 + x) = 3 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFAdd %float %float_1 %2\n" +
+      "%4 = OpFAdd %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true)
+));
+
+INSTANTIATE_TEST_CASE_P(MergeSubTest, MatchingInstructionFoldingTest,
+::testing::Values(
+  // Test case 0: merge sub of negate
+  // (-x) - 2 = -2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFNegate %float %2\n" +
+      "%4 = OpFSub %float %3 %float_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 1: merge sub of negate
+  // 2 - (-x) = x + 2
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFNegate %float %2\n" +
+      "%4 = OpFSub %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 2: merge sub of negate
+  // (-x) - 2 = -2 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+      "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+      "; CHECK: %4 = OpISub [[long]] [[long_n2]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_long Function\n" +
+      "%2 = OpLoad %long %var\n" +
+      "%3 = OpSNegate %long %2\n" +
+      "%4 = OpISub %long %3 %long_2\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 3: merge sub of negate
+  // 2 - (-x) = x + 2
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+      "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+      "; CHECK: %4 = OpIAdd [[long]] [[ld]] [[long_2]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_long Function\n" +
+      "%2 = OpLoad %long %var\n" +
+      "%3 = OpSNegate %long %2\n" +
+      "%4 = OpISub %long %long_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 4: merge add of subtract
+  // (x + 2) - 1 = x + 1
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFAdd %float %2 %float_2\n" +
+      "%4 = OpFSub %float %3 %float_1\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 5: merge add of subtract
+  // (2 + x) - 1 = x + 1
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFAdd %float %float_2 %2\n" +
+      "%4 = OpFSub %float %3 %float_1\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 6: merge add of subtract
+  // 2 - (x + 1) = 1 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_1]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFAdd %float %2 %float_1\n" +
+      "%4 = OpFSub %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 7: merge add of subtract
+  // 2 - (1 + x) = 1 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_1]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFAdd %float %float_1 %2\n" +
+      "%4 = OpFSub %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 8: merge subtract of subtract
+  // (x - 2) - 1 = x - 3
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[ld]] [[float_3]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %2 %float_2\n" +
+      "%4 = OpFSub %float %3 %float_1\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 9: merge subtract of subtract
+  // (2 - x) - 1 = 1 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_1]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %float_2 %2\n" +
+      "%4 = OpFSub %float %3 %float_1\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 10: merge subtract of subtract
+  // 2 - (x - 1) = 3 - x
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFSub [[float]] [[float_3]] [[ld]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %2 %float_1\n" +
+      "%4 = OpFSub %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 11: merge subtract of subtract
+  // 1 - (2 - x) = x + (-1)
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_n1:%\\w+]] = OpConstant [[float]] -1\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_n1]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %float_2 %2\n" +
+      "%4 = OpFSub %float %float_1 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true),
+  // Test case 12: merge subtract of subtract
+  // 2 - (1 - x) = x + 1
+  InstructionFoldingCase<bool>(
+    Header() +
+      "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+      "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+      "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+      "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFSub %float %float_1 %2\n" +
+      "%4 = OpFSub %float %float_2 %3\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, true)
+));
+#endif
 }  // anonymous namespace