const analysis::Constant* negated_const =
const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
- } else {
+ } else if (constants[1]->AsFloatConstant()) {
id = Reciprocal(const_mgr, constants[1]);
if (id == 0) return false;
+ } else {
+ // Don't fold a null constant.
+ return false;
}
inst->SetOpcode(SpvOpFMul);
inst->SetInOperands(
};
}
+// Returns true if |c| has a zero element.
+bool HasZero(const analysis::Constant* c) {
+ if (c->AsNullConstant()) {
+ return true;
+ }
+ if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
+ for (auto& comp : vec_const->GetComponents())
+ if (HasZero(comp)) return true;
+ } else {
+ assert(c->AsScalarConstant());
+ return c->AsScalarConstant()->IsZero();
+ }
+
+ return false;
+}
+
// Performs |input1| |opcode| |input2| and returns the merged constant result
// id. Returns 0 if the result is not a valid value. The input types must be
// Float.
FOLD_OP(*);
break;
case SpvOpFDiv:
+ if (HasZero(input2)) return 0;
FOLD_OP(/);
break;
case SpvOpFAdd:
const analysis::Type* ele_type = vector_type->element_type();
for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
uint32_t id = 0;
- const analysis::Constant* input1_comp =
- input1->AsVectorConstant()->GetComponents()[i];
- const analysis::Constant* input2_comp =
- input2->AsVectorConstant()->GetComponents()[i];
+
+ const analysis::Constant* input1_comp = nullptr;
+ if (const analysis::VectorConstant* input1_vector =
+ input1->AsVectorConstant()) {
+ input1_comp = input1_vector->GetComponents()[i];
+ } else {
+ assert(input1->AsNullConstant());
+ input1_comp = const_mgr->GetConstant(ele_type, {});
+ }
+
+ const analysis::Constant* input2_comp = nullptr;
+ if (const analysis::VectorConstant* input2_vector =
+ input2->AsVectorConstant()) {
+ input2_comp = input2_vector->GetComponents()[i];
+ } else {
+ assert(input2->AsNullConstant());
+ input2_comp = const_mgr->GetConstant(ele_type, {});
+ }
+
if (ele_type->AsFloat()) {
id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
input2_comp);
std::vector<const analysis::Constant*> other_constants =
const_mgr->GetOperandConstants(other_inst);
const analysis::Constant* const_input2 = ConstInput(other_constants);
- if (!const_input2) return false;
+ if (!const_input2 || HasZero(const_input2)) return false;
bool other_first_is_variable = other_constants[0] == nullptr;
// If the variable value is the second operand of the divide, multiply
if (width != 32 && width != 64) return false;
const analysis::Constant* const_input1 = ConstInput(constants);
- if (!const_input1) return false;
+ if (!const_input1 || HasZero(const_input1)) return false;
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
std::vector<const analysis::Constant*> other_constants =
const_mgr->GetOperandConstants(other_inst);
const analysis::Constant* const_input2 = ConstInput(other_constants);
- if (!const_input2) return false;
+ if (!const_input2 || HasZero(const_input2)) return false;
bool other_first_is_variable = other_constants[0] == nullptr;
if (width != 32 && width != 64) return false;
const analysis::Constant* const_input1 = ConstInput(constants);
- if (!const_input1) return false;
+ if (!const_input1 || HasZero(const_input1)) return false;
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
return FloatConstantKind::Unknown;
}
- if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) {
+ assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
+
+ if (constant->AsNullConstant()) {
+ return FloatConstantKind::Zero;
+ } else if (const analysis::VectorConstant* vc =
+ constant->AsVectorConstant()) {
const std::vector<const analysis::Constant*>& components =
vc->GetComponents();
assert(!components.empty());
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
+%v2float_null = OpConstantNull %v2float
%double_n1 = OpConstant %double -1
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
%double_0 = OpConstant %double 0
"%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
- 2, 4)
+ 2, 4),
+ // Test case 15: Fold vector fadd with null
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%a = OpVariable %_ptr_v2float Function\n" +
+ "%2 = OpLoad %v2float %a\n" +
+ "%3 = OpFAdd %v2float %2 %v2float_null\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 3, 2),
+ // Test case 16: Fold vector fadd with null
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%a = OpVariable %_ptr_v2float Function\n" +
+ "%2 = OpLoad %v2float %a\n" +
+ "%3 = OpFAdd %v2float %v2float_null %2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 3, 2),
+ // Test case 15: Fold vector fsub with null
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%a = OpVariable %_ptr_v2float Function\n" +
+ "%2 = OpLoad %v2float %a\n" +
+ "%3 = OpFSub %v2float %2 %v2float_null\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 3, 2)
));
INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest,
"%3 = OpFDiv %double %2 %double_2\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
- 3, true)
+ 3, true),
+ // Test case 4: don't fold x / 0.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_v2float Function\n" +
+ "%2 = OpLoad %v2float %var\n" +
+ "%3 = OpFDiv %v2float %2 %v2float_null\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 3, false)
));
INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest,
"%4 = OpSDiv %int %int_2 %3\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
- 4, true)
+ 4, true),
+ // Test case 13: Don't merge
+ // (x / {null}) / {null}
+ InstructionFoldingCase<bool>(
+ Header() +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%var = OpVariable %_ptr_v2float Function\n" +
+ "%2 = OpLoad %float %var\n" +
+ "%3 = OpFDiv %float %2 %v2float_null\n" +
+ "%4 = OpFDiv %float %3 %v2float_null\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 4, false)
));
INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest,