const uint32_t kFMixXIdInIdx = 2;
const uint32_t kFMixYIdInIdx = 3;
+// Returns the element width of |type|.
+uint32_t ElementWidth(const analysis::Type* type) {
+ if (const analysis::Vector* vec_type = type->AsVector()) {
+ return ElementWidth(vec_type->element_type());
+ } else if (const analysis::Float* float_type = type->AsFloat()) {
+ return float_type->width();
+ } else {
+ assert(type->AsInteger());
+ return type->AsInteger()->width();
+ }
+}
+
+// Returns true if |type| is Float or a vector of Float.
+bool HasFloatingPoint(const analysis::Type* type) {
+ if (type->AsFloat()) {
+ return true;
+ } else if (const analysis::Vector* vec_type = type->AsVector()) {
+ return vec_type->element_type()->AsFloat() != nullptr;
+ }
+
+ return false;
+}
+
+// Returns false if |val| is NaN, infinite or subnormal.
+template <typename T>
+bool IsValidResult(T val) {
+ int classified = std::fpclassify(val);
+ switch (classified) {
+ case FP_NAN:
+ case FP_INFINITE:
+ case FP_SUBNORMAL:
+ return false;
+ default:
+ return true;
+ }
+}
+
+const analysis::Constant* ConstInput(
+ const std::vector<const analysis::Constant*>& constants) {
+ return constants[0] ? constants[0] : constants[1];
+}
+
+ir::Instruction* NonConstInput(ir::IRContext* context,
+ const analysis::Constant* c,
+ ir::Instruction* inst) {
+ uint32_t in_op = c ? 1u : 0u;
+ return context->get_def_use_mgr()->GetDef(
+ inst->GetSingleWordInOperand(in_op));
+}
+
+// Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
+// constant.
+uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
+ const analysis::Constant* c) {
+ assert(c);
+ assert(c->type()->AsFloat());
+ uint32_t width = c->type()->AsFloat()->width();
+ assert(width == 32 || width == 64);
+ std::vector<uint32_t> words;
+ if (width == 64) {
+ spvutils::FloatProxy<double> result(c->GetDouble() * -1.0);
+ words = result.GetWords();
+ } else {
+ spvutils::FloatProxy<float> result(c->GetFloat() * -1.0f);
+ words = result.GetWords();
+ }
+
+ const analysis::Constant* negated_const =
+ const_mgr->GetConstant(c->type(), std::move(words));
+ return const_mgr->GetDefiningInstruction(negated_const)->result_id();
+}
+
+std::vector<uint32_t> ExtractInts(uint64_t val) {
+ std::vector<uint32_t> words;
+ words.push_back(static_cast<uint32_t>(val));
+ words.push_back(static_cast<uint32_t>(val >> 32));
+ return words;
+}
+
+// Negates the integer constant |c|. Returns the id of the defining instruction.
+uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
+ const analysis::Constant* c) {
+ assert(c);
+ assert(c->type()->AsInteger());
+ uint32_t width = c->type()->AsInteger()->width();
+ assert(width == 32 || width == 64);
+ std::vector<uint32_t> words;
+ if (width == 64) {
+ uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
+ words = ExtractInts(uval);
+ } else {
+ words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
+ }
+
+ const analysis::Constant* negated_const =
+ const_mgr->GetConstant(c->type(), std::move(words));
+ return const_mgr->GetDefiningInstruction(negated_const)->result_id();
+}
+
+// Negates the vector constant |c|. Returns the id of the defining instruction.
+uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
+ const analysis::Constant* c) {
+ assert(const_mgr && c);
+ assert(c->type()->AsVector());
+ if (c->AsNullConstant()) {
+ // 0.0 vs -0.0 shouldn't matter.
+ return const_mgr->GetDefiningInstruction(c)->result_id();
+ } else {
+ const analysis::Type* component_type =
+ c->AsVectorConstant()->component_type();
+ std::vector<uint32_t> words;
+ for (auto& comp : c->AsVectorConstant()->GetComponents()) {
+ if (component_type->AsFloat()) {
+ words.push_back(NegateFloatingPointConstant(const_mgr, comp));
+ } else {
+ assert(component_type->AsInteger());
+ words.push_back(NegateIntegerConstant(const_mgr, comp));
+ }
+ }
+
+ const analysis::Constant* negated_const =
+ const_mgr->GetConstant(c->type(), std::move(words));
+ return const_mgr->GetDefiningInstruction(negated_const)->result_id();
+ }
+}
+
+// Negates |c|. Returns the id of the defining instruction.
+uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
+ const analysis::Constant* c) {
+ if (c->type()->AsVector()) {
+ return NegateVectorConstant(const_mgr, c);
+ } else if (c->type()->AsFloat()) {
+ return NegateFloatingPointConstant(const_mgr, c);
+ } else {
+ assert(c->type()->AsInteger());
+ return NegateIntegerConstant(const_mgr, c);
+ }
+}
+
+// Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
+// Returns 0 if the reciprocal is NaN, infinite or subnormal.
+uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
+ const analysis::Constant* c) {
+ assert(const_mgr && c);
+ assert(c->type()->AsFloat());
+
+ uint32_t width = c->type()->AsFloat()->width();
+ assert(width == 32 || width == 64);
+ std::vector<uint32_t> words;
+ if (width == 64) {
+ spvutils::FloatProxy<double> result(1.0 / c->GetDouble());
+ if (!IsValidResult(result.getAsFloat())) return 0;
+ words = result.GetWords();
+ } else {
+ spvutils::FloatProxy<float> result(1.0f / c->GetFloat());
+ if (!IsValidResult(result.getAsFloat())) return 0;
+ words = result.GetWords();
+ }
+
+ const analysis::Constant* negated_const =
+ const_mgr->GetConstant(c->type(), std::move(words));
+ return const_mgr->GetDefiningInstruction(negated_const)->result_id();
+}
+
+// Replaces fdiv where second operand is constant with fmul.
+FoldingRule ReciprocalFDiv() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFDiv);
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ if (!inst->IsFloatingPointFoldingAllowed()) return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ if (constants[1] != nullptr) {
+ uint32_t id = 0;
+ if (const analysis::VectorConstant* vector_const =
+ constants[1]->AsVectorConstant()) {
+ std::vector<uint32_t> neg_ids;
+ for (auto& comp : vector_const->GetComponents()) {
+ id = Reciprocal(const_mgr, comp);
+ if (id == 0) return false;
+ neg_ids.push_back(id);
+ }
+ const analysis::Constant* negated_const =
+ const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
+ id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
+ } else {
+ id = Reciprocal(const_mgr, constants[1]);
+ if (id == 0) return false;
+ }
+ inst->SetOpcode(SpvOpFMul);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
+ {SPV_OPERAND_TYPE_ID, {id}}});
+ return true;
+ }
+
+ return false;
+ };
+};
+
+// Elides consecutive negate instructions.
+FoldingRule MergeNegateArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
+ (void)constants;
+ ir::IRContext* context = inst->context();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ ir::Instruction* op_inst =
+ context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
+ if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ if (op_inst->opcode() == inst->opcode()) {
+ // Elide negates.
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
+ return true;
+ }
+
+ return false;
+ };
+}
+
+// Merges negate into a mul or div operation if that operation contains a
+// constant operand.
+// Cases:
+// -(x * 2) = x * -2
+// -(2 * x) = x * -2
+// -(x / 2) = x / -2
+// -(2 / x) = -2 / x
+FoldingRule MergeNegateMulDivArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
+ (void)constants;
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ ir::Instruction* op_inst =
+ context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
+ if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ SpvOp opcode = op_inst->opcode();
+ if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
+ opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
+ std::vector<const analysis::Constant*> op_constants =
+ const_mgr->GetOperandConstants(op_inst);
+ // Merge negate into mul or div if one operand is constant.
+ if (op_constants[0] || op_constants[1]) {
+ bool zero_is_variable = op_constants[0] == nullptr;
+ const analysis::Constant* c = ConstInput(op_constants);
+ uint32_t neg_id = NegateConstant(const_mgr, c);
+ uint32_t non_const_id = zero_is_variable
+ ? op_inst->GetSingleWordInOperand(0u)
+ : op_inst->GetSingleWordInOperand(1u);
+ // Change this instruction to a mul/div.
+ inst->SetOpcode(op_inst->opcode());
+ if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
+ uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
+ uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
+ } else {
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
+ {SPV_OPERAND_TYPE_ID, {neg_id}}});
+ }
+ return true;
+ }
+ }
+
+ return false;
+ };
+}
+
+// Merges negate into a add or sub operation if that operation contains a
+// constant operand.
+// Cases:
+// -(x + 2) = -2 - x
+// -(2 + x) = -2 - x
+// -(x - 2) = 2 - x
+// -(2 - x) = x - 2
+FoldingRule MergeNegateAddSubArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
+ (void)constants;
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ ir::Instruction* op_inst =
+ context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
+ if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
+ op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
+ std::vector<const analysis::Constant*> op_constants =
+ const_mgr->GetOperandConstants(op_inst);
+ if (op_constants[0] || op_constants[1]) {
+ bool zero_is_variable = op_constants[0] == nullptr;
+ bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
+ (op_inst->opcode() == SpvOpIAdd);
+ bool swap_operands = !is_add || zero_is_variable;
+ bool negate_const = is_add;
+ const analysis::Constant* c = ConstInput(op_constants);
+ uint32_t const_id = 0;
+ if (negate_const) {
+ const_id = NegateConstant(const_mgr, c);
+ } else {
+ const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
+ : op_inst->GetSingleWordInOperand(0u);
+ }
+
+ // Swap operands if necessary and make the instruction a subtraction.
+ uint32_t op0 =
+ zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
+ uint32_t op1 =
+ zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
+ if (swap_operands) std::swap(op0, op1);
+ inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
+ return true;
+ }
+ }
+
+ return false;
+ };
+}
+
+// Performs |input1| |opcode| |input2| and returns the merged constant result
+// id. Returns 0 if the result is not a valid value. The input types must be
+// Float.
+uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
+ SpvOp opcode,
+ const analysis::Constant* input1,
+ const analysis::Constant* input2) {
+ const analysis::Type* type = input1->type();
+ assert(type->AsFloat());
+ uint32_t width = type->AsFloat()->width();
+ assert(width == 32 || width == 64);
+ std::vector<uint32_t> words;
+#define FOLD_OP(op) \
+ if (width == 64) { \
+ spvutils::FloatProxy<double> val = \
+ input1->GetDouble() op input2->GetDouble(); \
+ double dval = val.getAsFloat(); \
+ if (!IsValidResult(dval)) return 0; \
+ words = val.GetWords(); \
+ } else { \
+ spvutils::FloatProxy<float> val = \
+ input1->GetFloat() op input2->GetFloat(); \
+ float fval = val.getAsFloat(); \
+ if (!IsValidResult(fval)) return 0; \
+ words = val.GetWords(); \
+ }
+ switch (opcode) {
+ case SpvOpFMul:
+ FOLD_OP(*);
+ break;
+ case SpvOpFDiv:
+ FOLD_OP(/);
+ break;
+ case SpvOpFAdd:
+ FOLD_OP(+);
+ break;
+ case SpvOpFSub:
+ FOLD_OP(-);
+ break;
+ default:
+ assert(false && "Unexpected operation");
+ break;
+ }
+#undef FOLD_OP
+ const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
+ return const_mgr->GetDefiningInstruction(merged_const)->result_id();
+}
+
+// Performs |input1| |opcode| |input2| and returns the merged constant result
+// id. Returns 0 if the result is not a valid value. The input types must be
+// Integers.
+uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
+ SpvOp opcode, const analysis::Constant* input1,
+ const analysis::Constant* input2) {
+ assert(input1->type()->AsInteger());
+ const analysis::Integer* type = input1->type()->AsInteger();
+ uint32_t width = type->AsInteger()->width();
+ assert(width == 32 || width == 64);
+ std::vector<uint32_t> words;
+#define FOLD_OP(op) \
+ if (width == 64) { \
+ if (type->IsSigned()) { \
+ int64_t val = input1->GetS64() op input2->GetS64(); \
+ words = ExtractInts(static_cast<uint64_t>(val)); \
+ } else { \
+ uint64_t val = input1->GetU64() op input2->GetU64(); \
+ words = ExtractInts(val); \
+ } \
+ } else { \
+ if (type->IsSigned()) { \
+ int32_t val = input1->GetS32() op input2->GetS32(); \
+ words.push_back(static_cast<uint32_t>(val)); \
+ } else { \
+ uint32_t val = input1->GetU32() op input2->GetU32(); \
+ words.push_back(val); \
+ } \
+ }
+ switch (opcode) {
+ case SpvOpIMul:
+ FOLD_OP(*);
+ break;
+ case SpvOpSDiv:
+ case SpvOpUDiv:
+ // To avoid losing precision we won't perform division that would result
+ // in a remainder. Unfortunate code duplication results.
+ if (input2->AsIntConstant()->IsZero()) return 0;
+ if (width == 64) {
+ if (type->IsSigned()) {
+ if (input1->GetS64() % input2->GetS64() != 0) return 0;
+ int64_t val = input1->GetS64() / input2->GetS64();
+ words = ExtractInts(static_cast<uint64_t>(val));
+ } else {
+ if (input1->GetU64() % input2->GetU64() != 0) return 0;
+ uint64_t val = input1->GetU64() / input2->GetU64();
+ words = ExtractInts(val);
+ }
+ } else {
+ if (type->IsSigned()) {
+ if (input1->GetS32() % input2->GetS32() != 0) return 0;
+ int32_t val = input1->GetS32() / input2->GetS32();
+ words.push_back(static_cast<uint32_t>(val));
+ } else {
+ if (input1->GetU32() % input2->GetU32() != 0) return 0;
+ uint32_t val = input1->GetU32() / input2->GetU32();
+ words.push_back(val);
+ }
+ }
+ break;
+ case SpvOpIAdd:
+ FOLD_OP(+);
+ break;
+ case SpvOpISub:
+ FOLD_OP(-);
+ break;
+ default:
+ assert(false && "Unexpected operation");
+ break;
+ }
+#undef FOLD_OP
+ const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
+ return const_mgr->GetDefiningInstruction(merged_const)->result_id();
+}
+
+// Performs |input1| |opcode| |input2| and returns the merged constant result
+// id. Returns 0 if the result is not a valid value. The input types must be
+// Integers, Floats or Vectors of such.
+uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
+ const analysis::Constant* input1,
+ const analysis::Constant* input2) {
+ assert(input1 && input2);
+ assert(input1->type() == input2->type());
+ const analysis::Type* type = input1->type();
+ std::vector<uint32_t> words;
+ if (const analysis::Vector* vector_type = type->AsVector()) {
+ const analysis::Type* ele_type = vector_type->element_type();
+ for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
+ uint32_t id = 0;
+ const analysis::Constant* input1_comp =
+ input1->AsVectorConstant()->GetComponents()[i];
+ const analysis::Constant* input2_comp =
+ input2->AsVectorConstant()->GetComponents()[i];
+ if (ele_type->AsFloat()) {
+ id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
+ input2_comp);
+ } else {
+ assert(ele_type->AsInteger());
+ id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
+ input2_comp);
+ }
+ if (id == 0) return 0;
+ words.push_back(id);
+ }
+ const analysis::Constant* merged_const =
+ const_mgr->GetConstant(type, words);
+ return const_mgr->GetDefiningInstruction(merged_const)->result_id();
+ } else if (type->AsFloat()) {
+ return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
+ } else {
+ assert(type->AsInteger());
+ return PerformIntegerOperation(const_mgr, opcode, input1, input2);
+ }
+}
+
+// Merges consecutive multiplies where each contains one constant operand.
+// Cases:
+// 2 * (x * 2) = x * 4
+// 2 * (2 * x) = x * 4
+// (x * 2) * 2 = x * 4
+// (2 * x) * 2 = x * 4
+FoldingRule MergeMulMulArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ // Determine the constant input and the variable input in |inst|.
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ if (other_inst->opcode() == inst->opcode()) {
+ std::vector<const analysis::Constant*> other_constants =
+ const_mgr->GetOperandConstants(other_inst);
+ const analysis::Constant* const_input2 = ConstInput(other_constants);
+ if (!const_input2) return false;
+
+ bool other_first_is_variable = other_constants[0] == nullptr;
+ uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+ const_input1, const_input2);
+ if (merged_id == 0) return false;
+
+ uint32_t non_const_id = other_first_is_variable
+ ? other_inst->GetSingleWordInOperand(0u)
+ : other_inst->GetSingleWordInOperand(1u);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
+ {SPV_OPERAND_TYPE_ID, {merged_id}}});
+ return true;
+ }
+
+ return false;
+ };
+}
+
+// Merges divides into subsequent multiplies if each instruction contains one
+// constant operand.
+// Cases:
+// 2 * (x / 2) = 4 / x
+// 2 * (2 / x) = x * 1
+// (x / 2) * 2 = x * 1
+// (2 / x) * 2 = 4 / x
+FoldingRule MergeMulDivArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ if (other_inst->opcode() == SpvOpFDiv ||
+ other_inst->opcode() == SpvOpSDiv ||
+ other_inst->opcode() == SpvOpUDiv) {
+ std::vector<const analysis::Constant*> other_constants =
+ const_mgr->GetOperandConstants(other_inst);
+ const analysis::Constant* const_input2 = ConstInput(other_constants);
+ if (!const_input2) return false;
+
+ bool other_first_is_variable = other_constants[0] == nullptr;
+ // If the variable value is the second operand of the divide, multiply
+ // the constants together. Otherwise divide the constants.
+ uint32_t merged_id = PerformOperation(
+ const_mgr,
+ other_first_is_variable ? other_inst->opcode() : inst->opcode(),
+ const_input1, const_input2);
+ if (merged_id == 0) return false;
+
+ uint32_t non_const_id = other_first_is_variable
+ ? other_inst->GetSingleWordInOperand(0u)
+ : other_inst->GetSingleWordInOperand(1u);
+
+ // If the variable value is on the second operand of the div, then this
+ // operation is a div. Otherwise it should be a multiply.
+ inst->SetOpcode(other_first_is_variable ? inst->opcode()
+ : other_inst->opcode());
+ if (other_first_is_variable) {
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
+ {SPV_OPERAND_TYPE_ID, {merged_id}}});
+ } else {
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
+ {SPV_OPERAND_TYPE_ID, {non_const_id}}});
+ }
+ return true;
+ }
+
+ return false;
+ };
+}
+
+// Merges multiply of constant and negation.
+// Cases:
+// (-x) * 2 = x * -2
+// 2 * (-x) = x * -2
+FoldingRule MergeMulNegateArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ bool uses_float = HasFloatingPoint(type);
+ if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ if (other_inst->opcode() == SpvOpFNegate ||
+ other_inst->opcode() == SpvOpSNegate) {
+ uint32_t neg_id = NegateConstant(const_mgr, const_input1);
+
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
+ {SPV_OPERAND_TYPE_ID, {neg_id}}});
+ return true;
+ }
+
+ return false;
+ };
+}
+
+// Merges consecutive divides if each instruction contains one constant operand.
+// Cases:
+// 2 / (x / 2) = 4 / x
+// 4 / (2 / x) = 2 * x
+// (4 / x) / 2 = 2 / x
+// (x / 2) / 2 = x / 4
+FoldingRule MergeDivDivArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
+ inst->opcode() == SpvOpUDiv);
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ bool uses_float = HasFloatingPoint(type);
+ if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ bool first_is_variable = constants[0] == nullptr;
+ if (other_inst->opcode() == inst->opcode()) {
+ std::vector<const analysis::Constant*> other_constants =
+ const_mgr->GetOperandConstants(other_inst);
+ const analysis::Constant* const_input2 = ConstInput(other_constants);
+ if (!const_input2) return false;
+
+ bool other_first_is_variable = other_constants[0] == nullptr;
+
+ SpvOp merge_op = inst->opcode();
+ if (other_first_is_variable) {
+ // Constants magnify.
+ merge_op = uses_float ? SpvOpFMul : SpvOpIMul;
+ }
+
+ // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
+ // because it is commutative.
+ if (first_is_variable) std::swap(const_input1, const_input2);
+ uint32_t merged_id =
+ PerformOperation(const_mgr, merge_op, const_input1, const_input2);
+ if (merged_id == 0) return false;
+
+ uint32_t non_const_id = other_first_is_variable
+ ? other_inst->GetSingleWordInOperand(0u)
+ : other_inst->GetSingleWordInOperand(1u);
+
+ SpvOp op = inst->opcode();
+ if (!first_is_variable && !other_first_is_variable) {
+ // Effectively div of 1/x, so change to multiply.
+ op = uses_float ? SpvOpFMul : SpvOpIMul;
+ }
+
+ uint32_t op1 = merged_id;
+ uint32_t op2 = non_const_id;
+ if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
+ inst->SetOpcode(op);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+ return true;
+ }
+
+ return false;
+ };
+}
+
+// Fold multiplies succeeded by divides where each instruction contains a
+// constant operand.
+// Cases:
+// 4 / (x * 2) = 2 / x
+// 4 / (2 * x) = 2 / x
+// (x * 4) / 2 = x * 2
+// (4 * x) / 2 = x * 2
+FoldingRule MergeDivMulArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
+ inst->opcode() == SpvOpUDiv);
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ bool uses_float = HasFloatingPoint(type);
+ if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ bool first_is_variable = constants[0] == nullptr;
+ if (other_inst->opcode() == SpvOpFMul ||
+ other_inst->opcode() == SpvOpIMul) {
+ std::vector<const analysis::Constant*> other_constants =
+ const_mgr->GetOperandConstants(other_inst);
+ const analysis::Constant* const_input2 = ConstInput(other_constants);
+ if (!const_input2) return false;
+
+ bool other_first_is_variable = other_constants[0] == nullptr;
+
+ // This is an x / (*) case. Swap the inputs.
+ if (first_is_variable) std::swap(const_input1, const_input2);
+ uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+ const_input1, const_input2);
+ if (merged_id == 0) return false;
+
+ uint32_t non_const_id = other_first_is_variable
+ ? other_inst->GetSingleWordInOperand(0u)
+ : other_inst->GetSingleWordInOperand(1u);
+
+ uint32_t op1 = merged_id;
+ uint32_t op2 = non_const_id;
+ if (first_is_variable) std::swap(op1, op2);
+
+ // Convert to multiply
+ if (first_is_variable) inst->SetOpcode(other_inst->opcode());
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+ return true;
+ }
+
+ return false;
+ };
+}
+
+// Fold divides of a constant and a negation.
+// Cases:
+// (-x) / 2 = x / -2
+// 2 / (-x) = 2 / -x
+FoldingRule MergeDivNegateArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
+ inst->opcode() == SpvOpUDiv);
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ bool uses_float = HasFloatingPoint(type);
+ if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ bool first_is_variable = constants[0] == nullptr;
+ if (other_inst->opcode() == SpvOpFNegate ||
+ other_inst->opcode() == SpvOpSNegate) {
+ uint32_t neg_id = NegateConstant(const_mgr, const_input1);
+
+ if (first_is_variable) {
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
+ {SPV_OPERAND_TYPE_ID, {neg_id}}});
+ } else {
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {neg_id}},
+ {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
+ }
+ return true;
+ }
+
+ return false;
+ };
+}
+
+// Folds addition of a constant and a negation.
+// Cases:
+// (-x) + 2 = 2 - x
+// 2 + (-x) = 2 - x
+FoldingRule MergeAddNegateArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
+ ir::IRContext* context = inst->context();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ bool uses_float = HasFloatingPoint(type);
+ if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ if (other_inst->opcode() == SpvOpSNegate ||
+ other_inst->opcode() == SpvOpFNegate) {
+ inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
+ uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
+ : inst->GetSingleWordInOperand(1u);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {const_id}},
+ {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
+ return true;
+ }
+ return false;
+ };
+}
+
+// Folds subtraction of a constant and a negation.
+// Cases:
+// (-x) - 2 = -2 - x
+// 2 - (-x) = x + 2
+FoldingRule MergeSubNegateArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ bool uses_float = HasFloatingPoint(type);
+ if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ if (other_inst->opcode() == SpvOpSNegate ||
+ other_inst->opcode() == SpvOpFNegate) {
+ uint32_t op1 = 0;
+ uint32_t op2 = 0;
+ SpvOp opcode = inst->opcode();
+ if (constants[0] != nullptr) {
+ op1 = other_inst->GetSingleWordInOperand(0u);
+ op2 = inst->GetSingleWordInOperand(0u);
+ opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
+ } else {
+ op1 = NegateConstant(const_mgr, const_input1);
+ op2 = other_inst->GetSingleWordInOperand(0u);
+ }
+
+ inst->SetOpcode(opcode);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+ return true;
+ }
+ return false;
+ };
+}
+
+// Folds addition of an addition where each operation has a constant operand.
+// Cases:
+// (x + 2) + 2 = x + 4
+// (2 + x) + 2 = x + 4
+// 2 + (x + 2) = x + 4
+// 2 + (2 + x) = x + 4
+FoldingRule MergeAddAddArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
+ ir::IRContext* context = inst->context();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ bool uses_float = HasFloatingPoint(type);
+ if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ if (other_inst->opcode() == SpvOpFAdd ||
+ other_inst->opcode() == SpvOpIAdd) {
+ std::vector<const analysis::Constant*> other_constants =
+ const_mgr->GetOperandConstants(other_inst);
+ const analysis::Constant* const_input2 = ConstInput(other_constants);
+ if (!const_input2) return false;
+
+ ir::Instruction* non_const_input =
+ NonConstInput(context, other_constants[0], other_inst);
+ uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+ const_input1, const_input2);
+ if (merged_id == 0) return false;
+
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
+ {SPV_OPERAND_TYPE_ID, {merged_id}}});
+ return true;
+ }
+ return false;
+ };
+}
+
+// Folds addition of a subtraction where each operation has a constant operand.
+// Cases:
+// (x - 2) + 2 = x + 0
+// (2 - x) + 2 = 4 - x
+// 2 + (x - 2) = x + 0
+// 2 + (2 - x) = 4 - x
+FoldingRule MergeAddSubArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
+ ir::IRContext* context = inst->context();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ bool uses_float = HasFloatingPoint(type);
+ if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ if (other_inst->opcode() == SpvOpFSub ||
+ other_inst->opcode() == SpvOpISub) {
+ std::vector<const analysis::Constant*> other_constants =
+ const_mgr->GetOperandConstants(other_inst);
+ const analysis::Constant* const_input2 = ConstInput(other_constants);
+ if (!const_input2) return false;
+
+ bool first_is_variable = other_constants[0] == nullptr;
+ SpvOp op = inst->opcode();
+ uint32_t op1 = 0;
+ uint32_t op2 = 0;
+ if (first_is_variable) {
+ // Subtract constants. Non-constant operand is first.
+ op1 = other_inst->GetSingleWordInOperand(0u);
+ op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
+ const_input2);
+ } else {
+ // Add constants. Constant operand is first. Change the opcode.
+ op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
+ const_input2);
+ op2 = other_inst->GetSingleWordInOperand(1u);
+ op = other_inst->opcode();
+ }
+ if (op1 == 0 || op2 == 0) return false;
+
+ inst->SetOpcode(op);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+ return true;
+ }
+ return false;
+ };
+}
+
+// Folds subtraction of an addition where each operand has a constant operand.
+// Cases:
+// (x + 2) - 2 = x + 0
+// (2 + x) - 2 = x + 0
+// 2 - (x + 2) = 0 - x
+// 2 - (2 + x) = 0 - x
+FoldingRule MergeSubAddArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
+ ir::IRContext* context = inst->context();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ bool uses_float = HasFloatingPoint(type);
+ if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ if (other_inst->opcode() == SpvOpFAdd ||
+ other_inst->opcode() == SpvOpIAdd) {
+ std::vector<const analysis::Constant*> other_constants =
+ const_mgr->GetOperandConstants(other_inst);
+ const analysis::Constant* const_input2 = ConstInput(other_constants);
+ if (!const_input2) return false;
+
+ ir::Instruction* non_const_input =
+ NonConstInput(context, other_constants[0], other_inst);
+
+ // If the first operand of the sub is not a constant, swap the constants
+ // so the subtraction has the correct operands.
+ if (constants[0] == nullptr) std::swap(const_input1, const_input2);
+ // Subtract the constants.
+ uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+ const_input1, const_input2);
+ SpvOp op = inst->opcode();
+ uint32_t op1 = 0;
+ uint32_t op2 = 0;
+ if (constants[0] == nullptr) {
+ // Non-constant operand is first. Change the opcode.
+ op1 = non_const_input->result_id();
+ op2 = merged_id;
+ op = other_inst->opcode();
+ } else {
+ // Constant operand is first.
+ op1 = merged_id;
+ op2 = non_const_input->result_id();
+ }
+ if (op1 == 0 || op2 == 0) return false;
+
+ inst->SetOpcode(op);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+ return true;
+ }
+ return false;
+ };
+}
+
+// Folds subtraction of a subtraction where each operand has a constant operand.
+// Cases:
+// (x - 2) - 2 = x - 4
+// (2 - x) - 2 = 0 - x
+// 2 - (x - 2) = 4 - x
+// 2 - (2 - x) = x + 0
+FoldingRule MergeSubSubArithmetic() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants) {
+ assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
+ ir::IRContext* context = inst->context();
+ const analysis::Type* type =
+ context->get_type_mgr()->GetType(inst->type_id());
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ bool uses_float = HasFloatingPoint(type);
+ if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+ uint32_t width = ElementWidth(type);
+ if (width != 32 && width != 64) return false;
+
+ const analysis::Constant* const_input1 = ConstInput(constants);
+ if (!const_input1) return false;
+ ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
+ if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+ return false;
+
+ if (other_inst->opcode() == SpvOpFSub ||
+ other_inst->opcode() == SpvOpISub) {
+ std::vector<const analysis::Constant*> other_constants =
+ const_mgr->GetOperandConstants(other_inst);
+ const analysis::Constant* const_input2 = ConstInput(other_constants);
+ if (!const_input2) return false;
+
+ ir::Instruction* non_const_input =
+ NonConstInput(context, other_constants[0], other_inst);
+
+ // Merge the constants.
+ uint32_t merged_id = 0;
+ SpvOp merge_op = inst->opcode();
+ if (other_constants[0] == nullptr) {
+ merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
+ } else if (constants[0] == nullptr) {
+ std::swap(const_input1, const_input2);
+ }
+ merged_id =
+ PerformOperation(const_mgr, merge_op, const_input1, const_input2);
+ if (merged_id == 0) return false;
+
+ SpvOp op = inst->opcode();
+ if (constants[0] != nullptr && other_constants[0] != nullptr) {
+ // Change the operation.
+ op = uses_float ? SpvOpFAdd : SpvOpIAdd;
+ }
+
+ uint32_t op1 = 0;
+ uint32_t op2 = 0;
+ if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
+ op1 = merged_id;
+ op2 = non_const_input->result_id();
+ } else {
+ op1 = non_const_input->result_id();
+ op2 = merged_id;
+ }
+
+ inst->SetOpcode(op);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+ return true;
+ }
+ return false;
+ };
+}
+
FoldingRule IntMultipleBy1() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
continue;
}
const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
- if (int_constant && int_constant->GetU32BitValue() == 1) {
- inst->SetOpcode(SpvOpCopyObject);
- inst->SetInOperands(
- {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
- return true;
+ if (int_constant) {
+ uint32_t width = ElementWidth(int_constant->type());
+ if (width != 32 && width != 64) return false;
+ bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
+ : int_constant->GetU64BitValue() == 1ull;
+ if (is_one) {
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands(
+ {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
+ return true;
+ }
}
}
return false;
rules_[SpvOpExtInst].push_back(RedundantFMix());
rules_[SpvOpFAdd].push_back(RedundantFAdd());
+ rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
+ rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
+ rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
+
rules_[SpvOpFDiv].push_back(RedundantFDiv());
+ rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
+ rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
+ rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
+ rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
+
rules_[SpvOpFMul].push_back(RedundantFMul());
+ rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
+ rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
+ rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
+
+ rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
+ rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
+ rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
+
rules_[SpvOpFSub].push_back(RedundantFSub());
+ rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
+ rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
+ rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
+
+ rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
+ rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
+ rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
rules_[SpvOpIMul].push_back(IntMultipleBy1());
+ rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
+ rules_[SpvOpIMul].push_back(MergeMulDivArithmetic());
+ rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
+
+ rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
+ rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
+ rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
rules_[SpvOpPhi].push_back(RedundantPhi());
+ rules_[SpvOpSDiv].push_back(MergeDivDivArithmetic());
+ rules_[SpvOpSDiv].push_back(MergeDivMulArithmetic());
+ rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
+
+ rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
+ rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
+ rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
+
rules_[SpvOpSelect].push_back(RedundantSelect());
+
+ rules_[SpvOpUDiv].push_back(MergeDivDivArithmetic());
+ rules_[SpvOpUDiv].push_back(MergeDivMulArithmetic());
+ rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
}
} // namespace opt
#include <gtest/gtest.h>
#include <opt/fold.h>
+#ifdef SPIRV_EFFCEE
+#include "effcee/effcee.h"
+#endif
+
#include "opt/build_module.h"
#include "opt/def_use_manager.h"
#include "opt/ir_context.h"
using namespace spvtools;
using spvtools::opt::analysis::DefUseManager;
+std::string Disassemble(const std::string& original, ir::IRContext* context,
+ uint32_t disassemble_options = 0) {
+ std::vector<uint32_t> optimized_bin;
+ context->module()->ToBinary(&optimized_bin, true);
+ spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
+ SpirvTools tools(target_env);
+ std::string optimized_asm;
+ EXPECT_TRUE(
+ tools.Disassemble(optimized_bin, &optimized_asm, disassemble_options))
+ << "Disassembling failed for shader:\n"
+ << original << std::endl;
+ return optimized_asm;
+}
+
+#ifdef SPIRV_EFFCEE
+void Match(const std::string& original, ir::IRContext* context,
+ uint32_t disassemble_options = 0) {
+ std::string disassembly = Disassemble(original, context, disassemble_options);
+ auto match_result = effcee::Match(disassembly, original);
+ EXPECT_EQ(effcee::Result::Status::Ok, match_result.status())
+ << match_result.message() << "\nChecking result:\n"
+ << disassembly;
+}
+#endif
+
template <class ResultType>
struct InstructionFoldingCase {
InstructionFoldingCase(const std::string& tb, uint32_t id, ResultType result)
%short = OpTypeInt 16 1
%int = OpTypeInt 32 1
%long = OpTypeInt 64 1
-%uint = OpTypeInt 32 1
+%uint = OpTypeInt 32 0
%v2int = OpTypeVector %int 2
%v4int = OpTypeVector %int 4
%v4float = OpTypeVector %float 4
%v4double = OpTypeVector %double 4
+%v2float = OpTypeVector %float 2
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
%_ptr_int = OpTypePointer Function %int
%_ptr_uint = OpTypePointer Function %uint
%_ptr_bool = OpTypePointer Function %bool
%_ptr_float = OpTypePointer Function %float
%_ptr_double = OpTypePointer Function %double
+%_ptr_long = OpTypePointer Function %long
+%_ptr_v2int = OpTypePointer Function %v2int
%_ptr_v4float = OpTypePointer Function %v4float
%_ptr_v4double = OpTypePointer Function %v4double
%_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int
+%_ptr_v2float = OpTypePointer Function %v2float
%short_0 = OpConstant %short 0
%short_3 = OpConstant %short 3
%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
%103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps.
%int_0 = OpConstant %int 0
%int_1 = OpConstant %int 1
+%int_2 = OpConstant %int 2
%int_3 = OpConstant %int 3
+%int_4 = OpConstant %int 4
%int_min = OpConstant %int -2147483648
%int_max = OpConstant %int 2147483647
%long_0 = OpConstant %long 0
+%long_2 = OpConstant %long 2
%long_3 = OpConstant %long 3
%uint_0 = OpConstant %uint 0
+%uint_2 = OpConstant %uint 2
%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
%uint_32 = OpConstant %uint 32
-%uint_max = OpConstant %uint -1
+%uint_max = OpConstant %uint 4294967295
%v2int_undef = OpUndef %v2int
+%v2int_2_2 = OpConstantComposite %v2int %int_2 %int_2
+%v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3
+%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
+%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
%102 = OpConstantComposite %v2int %103 %103
%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
+%float_4 = OpConstant %float 4
+%float_0p5 = OpConstant %float 0.5
+%v2float_2_3 = OpConstantComposite %v2float %float_2 %float_3
+%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
+%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
%double_n1 = OpConstant %double -1
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
%double_0 = OpConstant %double 0
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
- // Test case 3: Don't fold n / 2.0
- InstructionFoldingCase<uint32_t>(
- Header() + "%main = OpFunction %void None %void_func\n" +
- "%main_lab = OpLabel\n" +
- "%n = OpVariable %_ptr_float Function\n" +
- "%3 = OpLoad %float %n\n" +
- "%2 = OpFDiv %float %3 %float_2\n" +
- "OpReturn\n" +
- "OpFunctionEnd",
- 2, 0),
- // Test case 4: Fold n + 0.0
+ // Test case 3: Fold n + 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 5: Fold 0.0 + n
+ // Test case 4: Fold 0.0 + n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 6: Fold n - 0.0
+ // Test case 5: Fold n - 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 7: Fold n * 1.0
+ // Test case 6: Fold n * 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 8: Fold 1.0 * n
+ // Test case 7: Fold 1.0 * n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 9: Fold n / 1.0
+ // Test case 8: Fold n / 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 10: Fold n * 0.0
+ // Test case 9: Fold n * 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, FLOAT_0_ID),
- // Test case 11: Fold 0.0 * n
+ // Test case 10: Fold 0.0 * n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, FLOAT_0_ID),
- // Test case 12: Fold 0.0 / n
+ // Test case 11: Fold 0.0 / n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, FLOAT_0_ID),
- // Test case 13: Don't fold mix(a, b, 2.0)
+ // Test case 12: Don't fold mix(a, b, 2.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
- // Test case 14: Fold mix(a, b, 0.0)
+ // Test case 13: Fold mix(a, b, 0.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 15: Fold mix(a, b, 1.0)
+ // Test case 14: Fold mix(a, b, 1.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
- // Test case 3: Don't fold n / 2.0
- InstructionFoldingCase<uint32_t>(
- Header() + "%main = OpFunction %void None %void_func\n" +
- "%main_lab = OpLabel\n" +
- "%n = OpVariable %_ptr_double Function\n" +
- "%3 = OpLoad %double %n\n" +
- "%2 = OpFDiv %double %3 %double_2\n" +
- "OpReturn\n" +
- "OpFunctionEnd",
- 2, 0),
- // Test case 4: Fold n + 0.0
+ // Test case 3: Fold n + 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 5: Fold 0.0 + n
+ // Test case 4: Fold 0.0 + n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 6: Fold n - 0.0
+ // Test case 5: Fold n - 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 7: Fold n * 1.0
+ // Test case 6: Fold n * 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 8: Fold 1.0 * n
+ // Test case 7: Fold 1.0 * n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 9: Fold n / 1.0
+ // Test case 8: Fold n / 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 10: Fold n * 0.0
+ // Test case 9: Fold n * 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, DOUBLE_0_ID),
- // Test case 11: Fold 0.0 * n
+ // Test case 10: Fold 0.0 * n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, DOUBLE_0_ID),
- // Test case 12: Fold 0.0 / n
+ // Test case 11: Fold 0.0 / n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, DOUBLE_0_ID),
- // Test case 13: Don't fold mix(a, b, 2.0)
+ // Test case 12: Don't fold mix(a, b, 2.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
- // Test case 14: Fold mix(a, b, 0.0)
+ // Test case 13: Fold mix(a, b, 0.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
- // Test case 15: Fold mix(a, b, 1.0)
+ // Test case 14: Fold mix(a, b, 1.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"OpFunctionEnd",
2, 3)
));
-// clang-format on
+#ifdef SPIRV_EFFCEE
+using MatchingInstructionFoldingTest =
+ ::testing::TestWithParam<InstructionFoldingCase<bool>>;
+
+TEST_P(MatchingInstructionFoldingTest, Case) {
+ const auto& tc = GetParam();
+
+ // Build module.
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, context);
+
+ // Fold the instruction to test.
+ opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+ ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+ std::unique_ptr<ir::Instruction> original_inst(inst->Clone(context.get()));
+ bool succeeded = opt::FoldInstruction(inst);
+ EXPECT_EQ(succeeded, tc.expected_result);
+ if (succeeded) {
+ Match(tc.test_body, context.get());
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(MergeNegateTest, MatchingInstructionFoldingTest,
+::testing::Values(
+ // Test case 0: fold consecutive fnegate
+ // -(-x) = x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float:%\\w+]]\n" +
+ "; CHECK: %4 = OpCopyObject [[float]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFNegate %float %2\n" +
+ "%4 = OpFNegate %float %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 1: fold fnegate(fmul with const).
+ // -(x * 2.0) = x * -2.0
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFMul %float %2 %float_2\n" +
+ "%4 = OpFNegate %float %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 2: fold fnegate(fmul with const).
+ // -(2.0 * x) = x * 2.0
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFMul %float %float_2 %2\n" +
+ "%4 = OpFNegate %float %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 3: fold fnegate(fdiv with const).
+ // -(x / 2.0) = x * -0.5
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_n0p5:%\\w+]] = OpConstant [[float]] -0.5\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n0p5]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFDiv %float %2 %float_2\n" +
+ "%4 = OpFNegate %float %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 4: fold fnegate(fdiv with const).
+ // -(2.0 / x) = -2.0 / x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFDiv [[float]] [[float_n2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFDiv %float %float_2 %2\n" +
+ "%4 = OpFNegate %float %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 5: fold fnegate(fadd with const).
+ // -(2.0 + x) = -2.0 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFAdd %float %float_2 %2\n" +
+ "%4 = OpFNegate %float %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 6: fold fnegate(fadd with const).
+ // -(x + 2.0) = -2.0 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFAdd %float %2 %float_2\n" +
+ "%4 = OpFNegate %float %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 7: fold fnegate(fsub with const).
+ // -(2.0 - x) = x - 2.0
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[ld]] [[float_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %float_2 %2\n" +
+ "%4 = OpFNegate %float %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 8: fold fnegate(fsub with const).
+ // -(x - 2.0) = 2.0 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %2 %float_2\n" +
+ "%4 = OpFNegate %float %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 9: fold consecutive snegate
+ // -(-x) = x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int:%\\w+]]\n" +
+ "; CHECK: %4 = OpCopyObject [[int]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpSNegate %int %2\n" +
+ "%4 = OpSNegate %int %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 10: fold consecutive vector negate
+ // -(-x) = x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[v2float:%\\w+]]\n" +
+ "; CHECK: %4 = OpCopyObject [[v2float]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_v2float Function\n" +
+ "%2 = OpLoad %v2float %var\n" +
+ "%3 = OpFNegate %v2float %2\n" +
+ "%4 = OpFNegate %v2float %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 11: fold snegate(iadd with const).
+ // -(2 + x) = -2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: OpConstant [[int]] -2147483648\n" +
+ "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpISub [[int]] [[int_n2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpIAdd %int %int_2 %2\n" +
+ "%4 = OpSNegate %int %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 12: fold snegate(iadd with const).
+ // -(x + 2) = -2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: OpConstant [[int]] -2147483648\n" +
+ "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpISub [[int]] [[int_n2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpIAdd %int %2 %int_2\n" +
+ "%4 = OpSNegate %int %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 13: fold snegate(isub with const).
+ // -(2 - x) = x - 2
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpISub [[int]] [[ld]] [[int_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpISub %int %int_2 %2\n" +
+ "%4 = OpSNegate %int %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 14: fold snegate(isub with const).
+ // -(x - 2) = 2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpISub [[int]] [[int_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpISub %int %2 %int_2\n" +
+ "%4 = OpSNegate %int %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 15: fold snegate(iadd with const).
+ // -(x + 2) = -2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+ "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+ "; CHECK: %4 = OpISub [[long]] [[long_n2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_long Function\n" +
+ "%2 = OpLoad %long %var\n" +
+ "%3 = OpIAdd %long %2 %long_2\n" +
+ "%4 = OpSNegate %long %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 16: fold snegate(isub with const).
+ // -(2 - x) = x - 2
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+ "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+ "; CHECK: %4 = OpISub [[long]] [[ld]] [[long_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_long Function\n" +
+ "%2 = OpLoad %long %var\n" +
+ "%3 = OpISub %long %long_2 %2\n" +
+ "%4 = OpSNegate %long %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 17: fold snegate(isub with const).
+ // -(x - 2) = 2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+ "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+ "; CHECK: %4 = OpISub [[long]] [[long_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_long Function\n" +
+ "%2 = OpLoad %long %var\n" +
+ "%3 = OpISub %long %2 %long_2\n" +
+ "%4 = OpSNegate %long %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true)
+));
+
+INSTANTIATE_TEST_CASE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,
+::testing::Values(
+ // Test case 0: scalar reicprocal
+ // x / 0.5 = x * 2.0
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %3 = OpFMul [[float]] [[ld]] [[float_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFDiv %float %2 %float_0p5\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 3, true),
+ // Test case 1: Unfoldable
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_0:%\\w+]] = OpConstant [[float]] 0\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %3 = OpFDiv [[float]] [[ld]] [[float_0]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFDiv %float %2 %104\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 3, false),
+ // Test case 2: Vector reciprocal
+ // x / {2.0, 0.5} = x * {0.5, 2.0}
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" +
+ "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+ "; CHECK: [[float_0p5:%\\w+]] = OpConstant [[float]] 0.5\n" +
+ "; CHECK: [[v2float_0p5_2:%\\w+]] = OpConstantComposite [[v2float]] [[float_0p5]] [[float_2]]\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[v2float]]\n" +
+ "; CHECK: %3 = OpFMul [[v2float]] [[ld]] [[v2float_0p5_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_v2float Function\n" +
+ "%2 = OpLoad %v2float %var\n" +
+ "%3 = OpFDiv %v2float %2 %v2float_2_0p5\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 3, true),
+ // Test case 3: double reciprocal
+ // x / 2.0 = x * 0.5
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
+ "; CHECK: [[double_0p5:%\\w+]] = OpConstant [[double]] 0.5\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[double]]\n" +
+ "; CHECK: %3 = OpFMul [[double]] [[ld]] [[double_0p5]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_double Function\n" +
+ "%2 = OpLoad %double %var\n" +
+ "%3 = OpFDiv %double %2 %double_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 3, true)
+));
+
+INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest,
+::testing::Values(
+ // Test case 0: fold consecutive fmuls
+ // (x * 3.0) * 2.0 = x * 6.0
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_6]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFMul %float %2 %float_3\n" +
+ "%4 = OpFMul %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 1: fold consecutive fmuls
+ // 2.0 * (x * 3.0) = x * 6.0
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_6]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFMul %float %2 %float_3\n" +
+ "%4 = OpFMul %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 2: fold consecutive fmuls
+ // (3.0 * x) * 2.0 = x * 6.0
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_6]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFMul %float %float_3 %2\n" +
+ "%4 = OpFMul %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 3: fold vector fmul
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" +
+ "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" +
+ "; CHECK: [[v2float_6_6:%\\w+]] = OpConstantComposite [[v2float]] [[float_6]] [[float_6]]\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[v2float]]\n" +
+ "; CHECK: %4 = OpFMul [[v2float]] [[ld]] [[v2float_6_6]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_v2float Function\n" +
+ "%2 = OpLoad %v2float %var\n" +
+ "%3 = OpFMul %v2float %2 %v2float_2_3\n" +
+ "%4 = OpFMul %v2float %3 %v2float_3_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 4: fold double fmuls
+ // (x * 3.0) * 2.0 = x * 6.0
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
+ "; CHECK: [[double_6:%\\w+]] = OpConstant [[double]] 6\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[double]]\n" +
+ "; CHECK: %4 = OpFMul [[double]] [[ld]] [[double_6]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_double Function\n" +
+ "%2 = OpLoad %double %var\n" +
+ "%3 = OpFMul %double %2 %double_3\n" +
+ "%4 = OpFMul %double %3 %double_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 5: fold 32 bit imuls
+ // (x * 3) * 2 = x * 6
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_6:%\\w+]] = OpConstant [[int]] 6\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_6]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpIMul %int %2 %int_3\n" +
+ "%4 = OpIMul %int %3 %int_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 6: fold 64 bit imuls
+ // (x * 3) * 2 = x * 6
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" +
+ "; CHECK: [[long_6:%\\w+]] = OpConstant [[long]] 6\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+ "; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_6]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_long Function\n" +
+ "%2 = OpLoad %long %var\n" +
+ "%3 = OpIMul %long %2 %long_3\n" +
+ "%4 = OpIMul %long %3 %long_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 7: merge vector integer mults
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: [[int_6:%\\w+]] = OpConstant [[int]] 6\n" +
+ "; CHECK: [[v2int_6_6:%\\w+]] = OpConstantComposite [[v2int]] [[int_6]] [[int_6]]\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" +
+ "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_6_6]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_v2int Function\n" +
+ "%2 = OpLoad %v2int %var\n" +
+ "%3 = OpIMul %v2int %2 %v2int_2_3\n" +
+ "%4 = OpIMul %v2int %3 %v2int_3_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 8: merge fmul of fdiv
+ // 2.0 * (2.0 / x) = 4.0 / x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_4:%\\w+]] = OpConstant [[float]] 4\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFDiv [[float]] [[float_4]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFDiv %float %float_2 %2\n" +
+ "%4 = OpFMul %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 9: merge fmul of fdiv
+ // (2.0 / x) * 2.0 = 4.0 / x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_4:%\\w+]] = OpConstant [[float]] 4\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFDiv [[float]] [[float_4]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFDiv %float %float_2 %2\n" +
+ "%4 = OpFMul %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 10: merge imul of sdiv
+ // 4 * (x / 2) = 2 * x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpSDiv %int %2 %int_2\n" +
+ "%4 = OpIMul %int %int_4 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 11: merge imul of sdiv
+ // (x / 2) * 4 = 2 * x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpSDiv %int %2 %int_2\n" +
+ "%4 = OpIMul %int %3 %int_4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 12: merge imul of udiv
+ // 4 * (x / 2) = 2 * x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
+ "; CHECK: [[uint_2:%\\w+]] = OpConstant [[uint]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[uint]]\n" +
+ "; CHECK: %4 = OpIMul [[uint]] [[ld]] [[uint_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_uint Function\n" +
+ "%2 = OpLoad %uint %var\n" +
+ "%3 = OpUDiv %uint %2 %uint_2\n" +
+ "%4 = OpIMul %uint %uint_4 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 13: merge imul of udiv
+ // (x / 2) * 4 = 2 * x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
+ "; CHECK: [[uint_2:%\\w+]] = OpConstant [[uint]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[uint]]\n" +
+ "; CHECK: %4 = OpIMul [[uint]] [[ld]] [[uint_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_uint Function\n" +
+ "%2 = OpLoad %uint %var\n" +
+ "%3 = OpUDiv %uint %2 %uint_2\n" +
+ "%4 = OpIMul %uint %3 %uint_4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 14: Don't fold if would have remainder
+ // (x / 3) * 4
+ InstructionFoldingCase<bool>(
+ Header() +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_uint Function\n" +
+ "%2 = OpLoad %uint %var\n" +
+ "%3 = OpUDiv %uint %2 %uint_3\n" +
+ "%4 = OpIMul %uint %3 %uint_4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, false),
+ // Test case 15: merge vector imul of sdiv
+ // (x / {2,2}) * {4,4} = x * {2,2}
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[v2int_2_2:%\\w+]] = OpConstantComposite [[v2int]] [[int_2]] [[int_2]]\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" +
+ "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_2_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_v2int Function\n" +
+ "%2 = OpLoad %v2int %var\n" +
+ "%3 = OpSDiv %v2int %2 %v2int_2_2\n" +
+ "%4 = OpIMul %v2int %3 %v2int_4_4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 15: merge vector imul of snegate
+ // (-x) * {2,2} = x * {-2,-2}
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: OpConstant [[int]] -2147483648\n" +
+ "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+ "; CHECK: [[v2int_n2_n2:%\\w+]] = OpConstantComposite [[v2int]] [[int_n2]] [[int_n2]]\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" +
+ "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_n2_n2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_v2int Function\n" +
+ "%2 = OpLoad %v2int %var\n" +
+ "%3 = OpSNegate %v2int %2\n" +
+ "%4 = OpIMul %v2int %3 %v2int_2_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 15: merge vector imul of snegate
+ // {2,2} * (-x) = x * {-2,-2}
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: OpConstant [[int]] -2147483648\n" +
+ "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+ "; CHECK: [[v2int_n2_n2:%\\w+]] = OpConstantComposite [[v2int]] [[int_n2]] [[int_n2]]\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" +
+ "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_n2_n2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_v2int Function\n" +
+ "%2 = OpLoad %v2int %var\n" +
+ "%3 = OpSNegate %v2int %2\n" +
+ "%4 = OpIMul %v2int %v2int_2_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true)
+));
+
+INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest,
+::testing::Values(
+ // Test case 0: merge consecutive fdiv
+ // 4.0 / (2.0 / x) = 2.0 * x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFMul [[float]] [[float_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFDiv %float %float_2 %2\n" +
+ "%4 = OpFDiv %float %float_4 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 1: merge consecutive fdiv
+ // 4.0 / (x / 2.0) = 8.0 / x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_8:%\\w+]] = OpConstant [[float]] 8\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFDiv [[float]] [[float_8]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFDiv %float %2 %float_2\n" +
+ "%4 = OpFDiv %float %float_4 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 2: merge consecutive fdiv
+ // (4.0 / x) / 2.0 = 2.0 / x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFDiv [[float]] [[float_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFDiv %float %float_4 %2\n" +
+ "%4 = OpFDiv %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 3: merge consecutive sdiv
+ // 4 / (2 / x) = 2 * x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpIMul [[int]] [[int_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpSDiv %int %int_2 %2\n" +
+ "%4 = OpSDiv %int %int_4 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 4: merge consecutive sdiv
+ // 4 / (x / 2) = 8 / x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_8:%\\w+]] = OpConstant [[int]] 8\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpSDiv [[int]] [[int_8]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpSDiv %int %2 %int_2\n" +
+ "%4 = OpSDiv %int %int_4 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 5: merge consecutive sdiv
+ // (4 / x) / 2 = 2 / x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpSDiv [[int]] [[int_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpSDiv %int %int_4 %2\n" +
+ "%4 = OpSDiv %int %3 %int_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 6: merge consecutive sdiv
+ // (x / 4) / 2 = x / 8
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_8:%\\w+]] = OpConstant [[int]] 8\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpSDiv [[int]] [[ld]] [[int_8]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpSDiv %int %2 %int_4\n" +
+ "%4 = OpSDiv %int %3 %int_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 7: merge sdiv of imul
+ // 4 / (2 * x) = 2 / x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpSDiv [[int]] [[int_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpIMul %int %int_2 %2\n" +
+ "%4 = OpSDiv %int %int_4 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 8: merge sdiv of imul
+ // 4 / (x * 2) = 2 / x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpSDiv [[int]] [[int_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpIMul %int %2 %int_2\n" +
+ "%4 = OpSDiv %int %int_4 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 9: merge sdiv of imul
+ // (4 * x) / 2 = x * 2
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpIMul %int %int_4 %2\n" +
+ "%4 = OpSDiv %int %3 %int_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 10: merge sdiv of imul
+ // (x * 4) / 2 = x * 2
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpIMul %int %2 %int_4\n" +
+ "%4 = OpSDiv %int %3 %int_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 11: merge sdiv of snegate
+ // (-x) / 2 = x / -2
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: OpConstant [[int]] -2147483648\n" +
+ "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpSDiv [[int]] [[ld]] [[int_n2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpSNegate %int %2\n" +
+ "%4 = OpSDiv %int %3 %int_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 12: merge sdiv of snegate
+ // 2 / (-x) = -2 / x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: OpConstant [[int]] -2147483648\n" +
+ "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
+ "; CHECK: %4 = OpSDiv [[int]] [[int_n2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_int Function\n" +
+ "%2 = OpLoad %int %var\n" +
+ "%3 = OpSNegate %int %2\n" +
+ "%4 = OpSDiv %int %int_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true)
+));
+
+INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest,
+::testing::Values(
+ // Test case 0: merge add of negate
+ // (-x) + 2 = 2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFNegate %float %2\n" +
+ "%4 = OpFAdd %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 1: merge add of negate
+ // 2 + (-x) = 2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpSNegate %float %2\n" +
+ "%4 = OpIAdd %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 2: merge add of negate
+ // (-x) + 2 = 2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+ "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+ "; CHECK: %4 = OpISub [[long]] [[long_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_long Function\n" +
+ "%2 = OpLoad %long %var\n" +
+ "%3 = OpSNegate %long %2\n" +
+ "%4 = OpIAdd %long %3 %long_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 3: merge add of negate
+ // 2 + (-x) = 2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+ "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+ "; CHECK: %4 = OpISub [[long]] [[long_2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_long Function\n" +
+ "%2 = OpLoad %long %var\n" +
+ "%3 = OpSNegate %long %2\n" +
+ "%4 = OpIAdd %long %long_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 4: merge add of subtract
+ // (x - 1) + 2 = x + 1
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %2 %float_1\n" +
+ "%4 = OpFAdd %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 5: merge add of subtract
+ // (1 - x) + 2 = 3 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_3]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %float_1 %2\n" +
+ "%4 = OpFAdd %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 6: merge add of subtract
+ // 2 + (x - 1) = x + 1
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %2 %float_1\n" +
+ "%4 = OpFAdd %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 7: merge add of subtract
+ // 2 + (1 - x) = 3 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_3]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %float_1 %2\n" +
+ "%4 = OpFAdd %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 8: merge add of add
+ // (x + 1) + 2 = x + 3
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFAdd %float %2 %float_1\n" +
+ "%4 = OpFAdd %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 9: merge add of add
+ // (1 + x) + 2 = 3 + x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFAdd %float %float_1 %2\n" +
+ "%4 = OpFAdd %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 10: merge add of add
+ // 2 + (x + 1) = x + 1
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFAdd %float %2 %float_1\n" +
+ "%4 = OpFAdd %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 11: merge add of add
+ // 2 + (1 + x) = 3 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFAdd %float %float_1 %2\n" +
+ "%4 = OpFAdd %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true)
+));
+
+INSTANTIATE_TEST_CASE_P(MergeSubTest, MatchingInstructionFoldingTest,
+::testing::Values(
+ // Test case 0: merge sub of negate
+ // (-x) - 2 = -2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFNegate %float %2\n" +
+ "%4 = OpFSub %float %3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 1: merge sub of negate
+ // 2 - (-x) = x + 2
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFNegate %float %2\n" +
+ "%4 = OpFSub %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 2: merge sub of negate
+ // (-x) - 2 = -2 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+ "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+ "; CHECK: %4 = OpISub [[long]] [[long_n2]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_long Function\n" +
+ "%2 = OpLoad %long %var\n" +
+ "%3 = OpSNegate %long %2\n" +
+ "%4 = OpISub %long %3 %long_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 3: merge sub of negate
+ // 2 - (-x) = x + 2
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
+ "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
+ "; CHECK: %4 = OpIAdd [[long]] [[ld]] [[long_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_long Function\n" +
+ "%2 = OpLoad %long %var\n" +
+ "%3 = OpSNegate %long %2\n" +
+ "%4 = OpISub %long %long_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 4: merge add of subtract
+ // (x + 2) - 1 = x + 1
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFAdd %float %2 %float_2\n" +
+ "%4 = OpFSub %float %3 %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 5: merge add of subtract
+ // (2 + x) - 1 = x + 1
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFAdd %float %float_2 %2\n" +
+ "%4 = OpFSub %float %3 %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 6: merge add of subtract
+ // 2 - (x + 1) = 1 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_1]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFAdd %float %2 %float_1\n" +
+ "%4 = OpFSub %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 7: merge add of subtract
+ // 2 - (1 + x) = 1 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_1]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFAdd %float %float_1 %2\n" +
+ "%4 = OpFSub %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 8: merge subtract of subtract
+ // (x - 2) - 1 = x - 3
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[ld]] [[float_3]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %2 %float_2\n" +
+ "%4 = OpFSub %float %3 %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 9: merge subtract of subtract
+ // (2 - x) - 1 = 1 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_1]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %float_2 %2\n" +
+ "%4 = OpFSub %float %3 %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 10: merge subtract of subtract
+ // 2 - (x - 1) = 3 - x
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFSub [[float]] [[float_3]] [[ld]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %2 %float_1\n" +
+ "%4 = OpFSub %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 11: merge subtract of subtract
+ // 1 - (2 - x) = x + (-1)
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_n1:%\\w+]] = OpConstant [[float]] -1\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_n1]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %float_2 %2\n" +
+ "%4 = OpFSub %float %float_1 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true),
+ // Test case 12: merge subtract of subtract
+ // 2 - (1 - x) = x + 1
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
+ "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
+ "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFSub %float %float_1 %2\n" +
+ "%4 = OpFSub %float %float_2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, true)
+));
+#endif
} // anonymous namespace