Adds the floating-point constant folding rules for FAdd, FDiv, FMul, and FSub.
Contributes to #1164.
namespace {
const uint32_t kExtractCompositeIdInIdx = 0;
+// Returns a vector that contains the two 32-bit integers that result from
+// splitting |a| in two. The first entry in the vector holds the low-order
+// bits of |a|.
+inline std::vector<uint32_t> ExtractInts(uint64_t a) {
+ std::vector<uint32_t> result;
+ result.push_back(static_cast<uint32_t>(a));
+ result.push_back(static_cast<uint32_t>(a >> 32));
+ return result;
+}
+
+// Returns true if we are allowed to fold or otherwise manipulate the
+// instruction that defines |id| in the given context.
+bool CanFoldFloatingPoint(ir::IRContext* context, uint32_t id) {
+ // TODO: Add the rules for kernels. For now it will be pessimistic.
+ if (!context->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
+ return false;
+ }
+
+ bool is_nocontract = false;
+ context->get_decoration_mgr()->WhileEachDecoration(
+ id, SpvDecorationNoContraction, [&is_nocontract](const ir::Instruction&) {
+ is_nocontract = true;
+ return false;
+ });
+ return !is_nocontract;
+}
+
+// Folds an OpCompositeExtract where input is a composite constant.
ConstantFoldingRule FoldExtractWithConstants() {
- // Folds an OpcompositeExtract where input is a composite constant.
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
ir::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
- const analysis::NullConstant null_const(
- type_mgr->GetType(inst->type_id()));
- const analysis::Constant* real_const =
- const_mgr->FindConstant(&null_const);
- if (real_const == nullptr) {
- ir::Instruction* const_inst =
- const_mgr->GetDefiningInstruction(&null_const);
- real_const = const_mgr->GetConstantFromInst(const_inst);
- }
- return real_const;
+ return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
}
auto cc = c->AsCompositeConstant();
return const_mgr->GetConstant(new_type, ids);
};
}
+
+// The interface for a function that returns the result of applying a scalar
+// floating-point binary operation on |a| and |b|. The type of the return value
+// will be |type|. The input constants must also be of type |type|.
+using FloatScalarFoldingRule = std::function<const analysis::FloatConstant*(
+ const analysis::Float* type, const analysis::Constant* a,
+ const analysis::Constant* b, analysis::ConstantManager*)>;
+
+// Returns an std::vector containing the elements of |constant|. The type of
+// |constant| must be |Vector|.
+std::vector<const analysis::Constant*> GetVectorComponents(
+ const analysis::Constant* constant, analysis::ConstantManager* const_mgr) {
+ std::vector<const analysis::Constant*> components;
+ const analysis::VectorConstant* a = constant->AsVectorConstant();
+ const analysis::Vector* vector_type = constant->type()->AsVector();
+ assert(vector_type != nullptr);
+ if (a != nullptr) {
+ for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
+ components.push_back(a->GetComponents()[i]);
+ }
+ } else {
+ const analysis::Type* element_type = vector_type->element_type();
+ const analysis::Constant* element_null_const =
+ const_mgr->GetConstant(element_type, {});
+ for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
+ components.push_back(element_null_const);
+ }
+ }
+ return components;
+}
+
+// Returns a |ConstantFoldingRule| that folds floating point scalars using
+// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
+// elements of the vector. The |ConstantFoldingRule| that is returned assumes
+// that |constants| contains 2 entries. If they are not |nullptr|, then their
+// type is either |Float| or a |Vector| whose element type is |Float|.
+ConstantFoldingRule FoldFloatingPointOp(FloatScalarFoldingRule scalar_rule) {
+ return [scalar_rule](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants)
+ -> const analysis::Constant* {
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+ const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
+ const analysis::Vector* vector_type = result_type->AsVector();
+ const analysis::Float* float_type = nullptr;
+
+ if (!CanFoldFloatingPoint(context, inst->result_id())) {
+ return nullptr;
+ }
+
+ if (constants[0] == nullptr || constants[1] == nullptr) {
+ return nullptr;
+ }
+
+ if (vector_type != nullptr) {
+ std::vector<const analysis::Constant*> a_componenets;
+ std::vector<const analysis::Constant*> b_componenets;
+ std::vector<const analysis::FloatConstant*> results_componenets;
+
+ float_type = vector_type->element_type()->AsFloat();
+ a_componenets = GetVectorComponents(constants[0], const_mgr);
+ b_componenets = GetVectorComponents(constants[1], const_mgr);
+
+ // Fold each component of the vector.
+ for (uint32_t i = 0; i < a_componenets.size(); ++i) {
+ results_componenets.push_back(scalar_rule(float_type, a_componenets[i],
+ b_componenets[i], const_mgr));
+ if (results_componenets[i] == nullptr) {
+ return nullptr;
+ }
+ }
+
+ // Build the constant object and return it.
+ std::vector<uint32_t> ids;
+ for (const analysis::FloatConstant* member : results_componenets) {
+ ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
+ }
+ return const_mgr->GetConstant(vector_type, ids);
+ } else {
+ float_type = result_type->AsFloat();
+ return scalar_rule(float_type, constants[0], constants[1], const_mgr);
+ }
+ };
+}
+
+// Returns the floating point value of |c|. The constant |c| must have type
+// |Float|, and width |32|.
+float GetFloatFromConst(const analysis::Constant* c) {
+ assert(c->type()->AsFloat() != nullptr &&
+ c->type()->AsFloat()->width() == 32);
+ const analysis::FloatConstant* fc = c->AsFloatConstant();
+ if (fc) {
+ return fc->GetFloatValue();
+ } else {
+ assert(c->AsNullConstant() && "c must be a float point constant.");
+ return 0.0f;
+ }
+}
+
+// Returns the double value of |c|. The constant |c| must have type
+// |Float|, and width |64|.
+double GetDoubleFromConst(const analysis::Constant* c) {
+ assert(c->type()->AsFloat() != nullptr &&
+ c->type()->AsFloat()->width() == 64);
+ const analysis::FloatConstant* fc = c->AsFloatConstant();
+ if (fc) {
+ return fc->GetDoubleValue();
+ } else {
+ assert(c->AsNullConstant() && "c must be a float point constant.");
+ return 0.0;
+ }
+}
+
+// This macro defines a |FloatScalarFoldingRule| that applies |op|. The
+// operator |op| must work for both float and double, and use syntax "f1 op f2".
+#define FOLD_OP(op) \
+ [](const analysis::Float* type, const analysis::Constant* a, \
+ const analysis::Constant* b, \
+ analysis::ConstantManager* const_mgr) -> const analysis::FloatConstant* { \
+ assert(type != nullptr && a != nullptr && b != nullptr); \
+ if (type->width() == 32) { \
+ float fa = GetFloatFromConst(a); \
+ float fb = GetFloatFromConst(b); \
+ spvutils::FloatProxy<float> result(fa op fb); \
+ std::vector<uint32_t> words = {result.data()}; \
+ return const_mgr->GetConstant(type, words)->AsFloatConstant(); \
+ } else if (type->width() == 64) { \
+ double fa = GetDoubleFromConst(a); \
+ double fb = GetDoubleFromConst(b); \
+ spvutils::FloatProxy<double> result(fa op fb); \
+ std::vector<uint32_t> words(ExtractInts(result.data())); \
+ return const_mgr->GetConstant(type, words)->AsFloatConstant(); \
+ } \
+ return nullptr; \
+ }
+
+// Define the folding rules for subtraction, addition, multiplication, and
+// division for floating point values.
+ConstantFoldingRule FoldFSub() { return FoldFloatingPointOp(FOLD_OP(-)); }
+ConstantFoldingRule FoldFAdd() { return FoldFloatingPointOp(FOLD_OP(+)); }
+ConstantFoldingRule FoldFMul() { return FoldFloatingPointOp(FOLD_OP(*)); }
+ConstantFoldingRule FoldFDiv() { return FoldFloatingPointOp(FOLD_OP(/)); }
} // namespace
spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
// Take that into consideration.
rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
+
rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
+
+ rules_[SpvOpFAdd].push_back(FoldFAdd());
+ rules_[SpvOpFDiv].push_back(FoldFDiv());
+ rules_[SpvOpFMul].push_back(FoldFMul());
+ rules_[SpvOpFSub].push_back(FoldFSub());
}
} // namespace opt
} // namespace spvtools
#include "module.h"
#include "type_manager.h"
#include "types.h"
+#include "util/hex_float.h"
namespace spvtools {
namespace opt {
std::unique_ptr<Constant> Copy() const override {
return std::unique_ptr<Constant>(CopyFloatConstant().release());
}
+
+ // Returns the float value of |this|. The type of |this| must be |Float| with
+ // width of 32.
+ float GetFloatValue() const {
+ assert(type()->AsFloat()->width() == 32 &&
+ "Not a 32-bit floating point value.");
+ spvutils::FloatProxy<float> a(words()[0]);
+ return a.getAsFloat();
+ }
+
+ // Returns the double value of |this|. The type of |this| must be |Float|
+ // with width of 64.
+ double GetDoubleValue() const {
+ assert(type()->AsFloat()->width() == 64 &&
+ "Not a 64-bit floating point value.");
+ // Reassemble the 64-bit pattern: words()[1] supplies the high-order 32 bits
+ // and words()[0] the low-order 32 bits.
+ uint64_t combined_words = words()[1];
+ combined_words = combined_words << 32;
+ combined_words |= words()[0];
+ spvutils::FloatProxy<double> a(combined_words);
+ return a.getAsFloat();
+ }
};
// Bool type constant.
return std::unique_ptr<Constant>(CopyVectorConstant().release());
}
- const Type* component_type() { return component_type_; }
+ const Type* component_type() const { return component_type_; }
private:
const Type* component_type_;
ir::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ if (!inst->IsFoldable() &&
+ !GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
+ return nullptr;
+ }
// Collect the values of the constant parameters.
std::vector<const analysis::Constant*> constants;
bool missing_constants = false;
if (!const_op) {
constants.push_back(nullptr);
missing_constants = true;
- return;
+ } else {
+ constants.push_back(const_op);
}
- constants.push_back(const_op);
});
if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
return const_mgr->GetDefiningInstruction(result_const);
}
-
return nullptr;
}
bool FoldInstruction(ir::Instruction* inst) {
bool modified = false;
ir::Instruction* folded_inst(inst);
- while (FoldInstructionInternal(&*folded_inst)) {
+ while (folded_inst->opcode() != SpvOpCopyObject &&
+ FoldInstructionInternal(&*folded_inst)) {
modified = true;
}
return modified;
for (size_t i = 0; i < work_list.size(); ++i) {
ir::Instruction* inst = work_list[i];
in_work_list.erase(inst);
- if (FoldInstruction(inst)) {
+ if (inst->opcode() == SpvOpCopyObject || FoldInstruction(inst)) {
modified = true;
context()->AnalyzeUses(inst);
get_def_use_mgr()->ForEachUser(
%void = OpTypeVoid
%void_func = OpTypeFunction %void
%bool = OpTypeBool
+%float16 = OpTypeFloat 16
+%float = OpTypeFloat 32
+%double = OpTypeFloat 64
%101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps.
%true = OpConstantTrue %bool
%false = OpConstantFalse %bool
%102 = OpConstantComposite %v2int %103 %103
%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
%struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
+%float16_0 = OpConstant %float16 0
+%float16_1 = OpConstant %float16 1
+%float16_2 = OpConstant %float16 2
+%float_n1 = OpConstant %float -1
+%float_0 = OpConstant %float 0
+%float_1 = OpConstant %float 1
+%float_2 = OpConstant %float 2
+%float_3 = OpConstant %float 3
+%double_n1 = OpConstant %double -1
+%double_0 = OpConstant %double 0
+%double_1 = OpConstant %double 1
+%double_2 = OpConstant %double 2
+%double_3 = OpConstant %double 3
)";
return header;
));
// clang-format on
+using FloatInstructionFoldingTest =
+ ::testing::TestWithParam<InstructionFoldingCase<float>>;
+
+TEST_P(FloatInstructionFoldingTest, Case) {
+ const auto& tc = GetParam();
+
+ // Build module.
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, context);
+
+ // Fold the instruction to test.
+ opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+ ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+ bool succeeded = opt::FoldInstruction(inst);
+
+ // Make sure the instruction folded as expected.
+ EXPECT_TRUE(succeeded);
+ if (inst != nullptr) {
+ EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
+ inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
+ EXPECT_EQ(inst->opcode(), SpvOpConstant);
+ opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr();
+ const opt::analysis::FloatConstant* result =
+ const_mrg->GetConstantFromInst(inst)->AsFloatConstant();
+ EXPECT_NE(result, nullptr);
+ if (result != nullptr) {
+ EXPECT_EQ(result->GetFloatValue(), tc.expected_result);
+ }
+ }
+}
+
+// Not testing NaNs because there are no expectations concerning NaNs according
+// to the "Precision and Operation of SPIR-V Instructions" section of the Vulkan
+// specification.
+
+// clang-format off
+INSTANTIATE_TEST_CASE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
+::testing::Values(
+ // Test case 0: Fold 2.0 - 1.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFSub %float %float_2 %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 1.0),
+ // Test case 1: Fold 2.0 + 1.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFAdd %float %float_2 %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3.0),
+ // Test case 2: Fold 3.0 * 2.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFMul %float %float_3 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 6.0),
+ // Test case 3: Fold 1.0 / 2.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFDiv %float %float_1 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0.5),
+ // Test case 4: Fold 1.0 / 0.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFDiv %float %float_1 %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, std::numeric_limits<float>::infinity()),
+ // Test case 5: Fold -1.0 / 0.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFDiv %float %float_n1 %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, -std::numeric_limits<float>::infinity())
+));
+// clang-format on
+
+using DoubleInstructionFoldingTest =
+ ::testing::TestWithParam<InstructionFoldingCase<double>>;
+
+TEST_P(DoubleInstructionFoldingTest, Case) {
+ const auto& tc = GetParam();
+
+ // Build module.
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, context);
+
+ // Fold the instruction to test.
+ opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+ ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+ bool succeeded = opt::FoldInstruction(inst);
+
+ // Make sure the instruction folded as expected.
+ EXPECT_TRUE(succeeded);
+ if (inst != nullptr) {
+ EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
+ inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
+ EXPECT_EQ(inst->opcode(), SpvOpConstant);
+ opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr();
+ const opt::analysis::FloatConstant* result =
+ const_mrg->GetConstantFromInst(inst)->AsFloatConstant();
+ EXPECT_NE(result, nullptr);
+ if (result != nullptr) {
+ EXPECT_EQ(result->GetDoubleValue(), tc.expected_result);
+ }
+ }
+}
+
+// clang-format off
+INSTANTIATE_TEST_CASE_P(DoubleConstantFoldingTest, DoubleInstructionFoldingTest,
+::testing::Values(
+ // Test case 0: Fold 2.0 - 1.0
+ InstructionFoldingCase<double>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFSub %double %double_2 %double_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 1.0),
+ // Test case 1: Fold 2.0 + 1.0
+ InstructionFoldingCase<double>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFAdd %double %double_2 %double_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3.0),
+ // Test case 2: Fold 3.0 * 2.0
+ InstructionFoldingCase<double>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFMul %double %double_3 %double_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 6.0),
+ // Test case 3: Fold 1.0 / 2.0
+ InstructionFoldingCase<double>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFDiv %double %double_1 %double_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0.5),
+ // Test case 4: Fold 1.0 / 0.0
+ InstructionFoldingCase<double>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFDiv %double %double_1 %double_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, std::numeric_limits<double>::infinity()),
+ // Test case 5: Fold -1.0 / 0.0
+ InstructionFoldingCase<double>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpFDiv %double %double_n1 %double_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, -std::numeric_limits<double>::infinity())
+));
+// clang-format on
template <class ResultType>
struct InstructionFoldingCaseWithMap {
InstructionFoldingCaseWithMap(const std::string& tb, uint32_t id,