Fold binary floating point operators.
authorSteven Perron <stevenperron@google.com>
Fri, 9 Feb 2018 18:37:26 +0000 (13:37 -0500)
committerSteven Perron <stevenperron@google.com>
Wed, 14 Feb 2018 20:48:15 +0000 (15:48 -0500)
Adds the floating rules for FAdd, FDiv, FMul, and FSub.

Contributes to #1164.

source/opt/const_folding_rules.cpp
source/opt/constants.h
source/opt/fold.cpp
source/opt/simplification_pass.cpp
test/opt/fold_test.cpp

index f4492db..0a08bce 100644 (file)
@@ -20,8 +20,35 @@ namespace opt {
 namespace {
 const uint32_t kExtractCompositeIdInIdx = 0;
 
+// Returns a vector that contains the two 32-bit integers that result from
+// splitting |a| in two.  The first entry in vector are the low order bit if
+// |a|.
+inline std::vector<uint32_t> ExtractInts(uint64_t a) {
+  std::vector<uint32_t> result;
+  result.push_back(static_cast<uint32_t>(a));
+  result.push_back(static_cast<uint32_t>(a >> 32));
+  return result;
+}
+
+// Returns true if we are allowed to fold or otherwise manipulate the
+// instruction that defines |id| in the given context.
+bool CanFoldFloatingPoint(ir::IRContext* context, uint32_t id) {
+  // TODO: Add the rules for kernels.  For now it will be pessimistic.
+  if (!context->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
+    return false;
+  }
+
+  bool is_nocontract = false;
+  context->get_decoration_mgr()->WhileEachDecoration(
+      id, SpvDecorationNoContraction, [&is_nocontract](const ir::Instruction&) {
+        is_nocontract = true;
+        return false;
+      });
+  return !is_nocontract;
+}
+
+// Folds an OpcompositeExtract where input is a composite constant.
 ConstantFoldingRule FoldExtractWithConstants() {
-  // Folds an OpcompositeExtract where input is a composite constant.
   return [](ir::Instruction* inst,
             const std::vector<const analysis::Constant*>& constants)
              -> const analysis::Constant* {
@@ -37,16 +64,7 @@ ConstantFoldingRule FoldExtractWithConstants() {
         ir::IRContext* context = inst->context();
         analysis::ConstantManager* const_mgr = context->get_constant_mgr();
         analysis::TypeManager* type_mgr = context->get_type_mgr();
-        const analysis::NullConstant null_const(
-            type_mgr->GetType(inst->type_id()));
-        const analysis::Constant* real_const =
-            const_mgr->FindConstant(&null_const);
-        if (real_const == nullptr) {
-          ir::Instruction* const_inst =
-              const_mgr->GetDefiningInstruction(&null_const);
-          real_const = const_mgr->GetConstantFromInst(const_inst);
-        }
-        return real_const;
+        return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
       }
 
       auto cc = c->AsCompositeConstant();
@@ -83,6 +101,149 @@ ConstantFoldingRule FoldCompositeWithConstants() {
     return const_mgr->GetConstant(new_type, ids);
   };
 }
+
+// The interface for a function that returns the result of applying a scalar
+// floating-point binary operation on |a| and |b|.  The type of the return value
+// will be |type|.  The input constants must also be of type |type|.
+using FloatScalarFoldingRule = std::function<const analysis::FloatConstant*(
+    const analysis::Float* type, const analysis::Constant* a,
+    const analysis::Constant* b, analysis::ConstantManager*)>;
+
+// Returns an std::vector containing the elements of |constant|.  The type of
+// |constant| must be |Vector|.
+std::vector<const analysis::Constant*> GetVectorComponents(
+    const analysis::Constant* constant, analysis::ConstantManager* const_mgr) {
+  std::vector<const analysis::Constant*> components;
+  const analysis::VectorConstant* a = constant->AsVectorConstant();
+  const analysis::Vector* vector_type = constant->type()->AsVector();
+  assert(vector_type != nullptr);
+  if (a != nullptr) {
+    for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
+      components.push_back(a->GetComponents()[i]);
+    }
+  } else {
+    const analysis::Type* element_type = vector_type->element_type();
+    const analysis::Constant* element_null_const =
+        const_mgr->GetConstant(element_type, {});
+    for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
+      components.push_back(element_null_const);
+    }
+  }
+  return components;
+}
+
+// Returns a |ConstantFoldingRule| that folds floating point scalars using
+// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
+// elements of the vector.  The |ConstantFoldingRule| that is returned assumes
+// that |constants| contains 2 entries.  If they are not |nullptr|, then their
+// type is either |Float| or a |Vector| whose element type is |Float|.
+ConstantFoldingRule FoldFloatingPointOp(FloatScalarFoldingRule scalar_rule) {
+  return [scalar_rule](ir::Instruction* inst,
+                       const std::vector<const analysis::Constant*>& constants)
+             -> const analysis::Constant* {
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    analysis::TypeManager* type_mgr = context->get_type_mgr();
+    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
+    const analysis::Vector* vector_type = result_type->AsVector();
+    const analysis::Float* float_type = nullptr;
+
+    if (!CanFoldFloatingPoint(context, inst->result_id())) {
+      return nullptr;
+    }
+
+    if (constants[0] == nullptr || constants[1] == nullptr) {
+      return nullptr;
+    }
+
+    if (vector_type != nullptr) {
+      std::vector<const analysis::Constant*> a_componenets;
+      std::vector<const analysis::Constant*> b_componenets;
+      std::vector<const analysis::FloatConstant*> results_componenets;
+
+      float_type = vector_type->element_type()->AsFloat();
+      a_componenets = GetVectorComponents(constants[0], const_mgr);
+      b_componenets = GetVectorComponents(constants[1], const_mgr);
+
+      // Fold each component of the vector.
+      for (uint32_t i = 0; i < a_componenets.size(); ++i) {
+        results_componenets.push_back(scalar_rule(float_type, a_componenets[i],
+                                                  b_componenets[i], const_mgr));
+        if (results_componenets[i] == nullptr) {
+          return nullptr;
+        }
+      }
+
+      // Build the constant object and return it.
+      std::vector<uint32_t> ids;
+      for (const analysis::FloatConstant* member : results_componenets) {
+        ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
+      }
+      return const_mgr->GetConstant(vector_type, ids);
+    } else {
+      float_type = result_type->AsFloat();
+      return scalar_rule(float_type, constants[0], constants[1], const_mgr);
+    }
+  };
+}
+
+// Returns the floating point value of |c|.  The constant |c| must have type
+// |Float|, and width |32|.
+float GetFloatFromConst(const analysis::Constant* c) {
+  assert(c->type()->AsFloat() != nullptr &&
+         c->type()->AsFloat()->width() == 32);
+  const analysis::FloatConstant* fc = c->AsFloatConstant();
+  if (fc) {
+    return fc->GetFloatValue();
+  } else {
+    assert(c->AsNullConstant() && "c must be a float point constant.");
+    return 0.0f;
+  }
+}
+
+// Returns the double value of |c|.  The constant |c| must have type
+// |Float|, and width |64|.
+double GetDoubleFromConst(const analysis::Constant* c) {
+  assert(c->type()->AsFloat() != nullptr &&
+         c->type()->AsFloat()->width() == 64);
+  const analysis::FloatConstant* fc = c->AsFloatConstant();
+  if (fc) {
+    return fc->GetDoubleValue();
+  } else {
+    assert(c->AsNullConstant() && "c must be a float point constant.");
+    return 0.0;
+  }
+}
+
+// This macro defines a |FloatScalarFoldingRule| that applies |op|.  The
+// operator |op| must work for both float and double, and use syntax "f1 op f2".
+#define FOLD_OP(op)                                                            \
+  [](const analysis::Float* type, const analysis::Constant* a,                 \
+     const analysis::Constant* b,                                              \
+     analysis::ConstantManager* const_mgr) -> const analysis::FloatConstant* { \
+    assert(type != nullptr && a != nullptr && b != nullptr);                   \
+    if (type->width() == 32) {                                                 \
+      float fa = GetFloatFromConst(a);                                         \
+      float fb = GetFloatFromConst(b);                                         \
+      spvutils::FloatProxy<float> result(fa op fb);                            \
+      std::vector<uint32_t> words = {result.data()};                           \
+      return const_mgr->GetConstant(type, words)->AsFloatConstant();           \
+    } else if (type->width() == 64) {                                          \
+      double fa = GetDoubleFromConst(a);                                       \
+      double fb = GetDoubleFromConst(b);                                       \
+      spvutils::FloatProxy<double> result(fa op fb);                           \
+      std::vector<uint32_t> words(ExtractInts(result.data()));                 \
+      return const_mgr->GetConstant(type, words)->AsFloatConstant();           \
+    }                                                                          \
+    return nullptr;                                                            \
+  }
+
+// Define the folding rules for subtraction, addition, multiplication, and
+// division for floating point values.
+ConstantFoldingRule FoldFSub() { return FoldFloatingPointOp(FOLD_OP(-)); }
+ConstantFoldingRule FoldFAdd() { return FoldFloatingPointOp(FOLD_OP(+)); }
+ConstantFoldingRule FoldFMul() { return FoldFloatingPointOp(FOLD_OP(*)); }
+ConstantFoldingRule FoldFDiv() { return FoldFloatingPointOp(FOLD_OP(/)); }
 }  // namespace
 
 spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
@@ -92,7 +253,13 @@ spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
   // Take that into consideration.
 
   rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
+
   rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
+
+  rules_[SpvOpFAdd].push_back(FoldFAdd());
+  rules_[SpvOpFDiv].push_back(FoldFDiv());
+  rules_[SpvOpFMul].push_back(FoldFMul());
+  rules_[SpvOpFSub].push_back(FoldFSub());
 }
 }  // namespace opt
 }  // namespace spvtools
index 3eb3411..e382e16 100644 (file)
@@ -26,6 +26,7 @@
 #include "module.h"
 #include "type_manager.h"
 #include "types.h"
+#include "util/hex_float.h"
 
 namespace spvtools {
 namespace opt {
@@ -172,6 +173,27 @@ class FloatConstant : public ScalarConstant {
   std::unique_ptr<Constant> Copy() const override {
     return std::unique_ptr<Constant>(CopyFloatConstant().release());
   }
+
+  // Returns the float value of |this|.  The type of |this| must be |Float| with
+  // width of 32.
+  float GetFloatValue() const {
+    assert(type()->AsFloat()->width() == 32 &&
+           "Not a 32-bit floating point value.");
+    spvutils::FloatProxy<float> a(words()[0]);
+    return a.getAsFloat();
+  }
+
+  // Returns the double value of |this|.  The type of |this| must be |Float|
+  // with width of 64.
+  double GetDoubleValue() const {
+    assert(type()->AsFloat()->width() == 64 &&
+           "Not a 32-bit floating point value.");
+    uint64_t combined_words = words()[1];
+    combined_words = combined_words << 32;
+    combined_words |= words()[0];
+    spvutils::FloatProxy<double> a(combined_words);
+    return a.getAsFloat();
+  }
 };
 
 // Bool type constant.
@@ -269,7 +291,7 @@ class VectorConstant : public CompositeConstant {
     return std::unique_ptr<Constant>(CopyVectorConstant().release());
   }
 
-  const Type* component_type() { return component_type_; }
+  const Type* component_type() const { return component_type_; }
 
  private:
   const Type* component_type_;
index f3f51cd..ab7239d 100644 (file)
@@ -612,6 +612,10 @@ ir::Instruction* FoldInstructionToConstant(
   ir::IRContext* context = inst->context();
   analysis::ConstantManager* const_mgr = context->get_constant_mgr();
 
+  if (!inst->IsFoldable() &&
+      !GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
+    return nullptr;
+  }
   // Collect the values of the constant parameters.
   std::vector<const analysis::Constant*> constants;
   bool missing_constants = false;
@@ -622,9 +626,9 @@ ir::Instruction* FoldInstructionToConstant(
     if (!const_op) {
       constants.push_back(nullptr);
       missing_constants = true;
-      return;
+    } else {
+      constants.push_back(const_op);
     }
-    constants.push_back(const_op);
   });
 
   if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
@@ -659,7 +663,6 @@ ir::Instruction* FoldInstructionToConstant(
         const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
     return const_mgr->GetDefiningInstruction(result_const);
   }
-
   return nullptr;
 }
 
@@ -679,7 +682,8 @@ bool IsFoldableType(ir::Instruction* type_inst) {
 bool FoldInstruction(ir::Instruction* inst) {
   bool modified = false;
   ir::Instruction* folded_inst(inst);
-  while (FoldInstructionInternal(&*folded_inst)) {
+  while (folded_inst->opcode() != SpvOpCopyObject &&
+         FoldInstructionInternal(&*folded_inst)) {
     modified = true;
   }
   return modified;
index 47706e8..356ab90 100644 (file)
@@ -82,7 +82,7 @@ bool SimplificationPass::SimplifyFunction(ir::Function* function) {
   for (size_t i = 0; i < work_list.size(); ++i) {
     ir::Instruction* inst = work_list[i];
     in_work_list.erase(inst);
-    if (FoldInstruction(inst)) {
+    if (inst->opcode() == SpvOpCopyObject || FoldInstruction(inst)) {
       modified = true;
       context()->AnalyzeUses(inst);
       get_def_use_mgr()->ForEachUser(
index ac40f90..0cc1e3a 100644 (file)
@@ -91,6 +91,9 @@ OpName %main "main"
 %void = OpTypeVoid
 %void_func = OpTypeFunction %void
 %bool = OpTypeBool
+%float16 = OpTypeFloat 16
+%float = OpTypeFloat 32
+%double = OpTypeFloat 64
 %101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps.
 %true = OpConstantTrue %bool
 %false = OpConstantFalse %bool
@@ -124,6 +127,19 @@ OpName %main "main"
 %102 = OpConstantComposite %v2int %103 %103
 %v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
 %struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
+%float16_0 = OpConstant %float16 0
+%float16_1 = OpConstant %float16 1
+%float16_2 = OpConstant %float16 2
+%float_n1 = OpConstant %float -1
+%float_0 = OpConstant %float 0
+%float_1 = OpConstant %float 1
+%float_2 = OpConstant %float 2
+%float_3 = OpConstant %float 3
+%double_n1 = OpConstant %double -1
+%double_0 = OpConstant %double 0
+%double_1 = OpConstant %double 1
+%double_2 = OpConstant %double 2
+%double_3 = OpConstant %double 3
 )";
 
   return header;
@@ -545,6 +561,183 @@ INSTANTIATE_TEST_CASE_P(TestCase, BooleanInstructionFoldingTest,
 ));
 // clang-format on
 
+using FloatInstructionFoldingTest =
+    ::testing::TestWithParam<InstructionFoldingCase<float>>;
+
+TEST_P(FloatInstructionFoldingTest, Case) {
+  const auto& tc = GetParam();
+
+  // Build module.
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ASSERT_NE(nullptr, context);
+
+  // Fold the instruction to test.
+  opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+  ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+  bool succeeded = opt::FoldInstruction(inst);
+
+  // Make sure the instruction folded as expected.
+  EXPECT_TRUE(succeeded);
+  if (inst != nullptr) {
+    EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
+    inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
+    EXPECT_EQ(inst->opcode(), SpvOpConstant);
+    opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr();
+    const opt::analysis::FloatConstant* result =
+        const_mrg->GetConstantFromInst(inst)->AsFloatConstant();
+    EXPECT_NE(result, nullptr);
+    if (result != nullptr) {
+      EXPECT_EQ(result->GetFloatValue(), tc.expected_result);
+    }
+  }
+}
+
+// Not testing NaNs because there are no expectations concerning NaNs according
+// to the "Precision and Operation of SPIR-V Instructions" section of the Vulkan
+// specification.
+
+// clang-format off
+INSTANTIATE_TEST_CASE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
+::testing::Values(
+    // Test case 0: Fold 2.0 - 1.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpFSub %float %float_2 %float_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 1.0),
+    // Test case 1: Fold 2.0 + 1.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpFAdd %float %float_2 %float_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 3.0),
+    // Test case 2: Fold 3.0 * 2.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpFMul %float %float_3 %float_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 6.0),
+    // Test case 3: Fold 1.0 / 2.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpFDiv %float %float_1 %float_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0.5),
+    // Test case 4: Fold 1.0 / 0.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpFDiv %float %float_1 %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, std::numeric_limits<float>::infinity()),
+    // Test case 4: Fold -1.0 / 0.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpFDiv %float %float_n1 %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, -std::numeric_limits<float>::infinity())
+));
+// clang-format on
+
+using DoubleInstructionFoldingTest =
+    ::testing::TestWithParam<InstructionFoldingCase<double>>;
+
+TEST_P(DoubleInstructionFoldingTest, Case) {
+  const auto& tc = GetParam();
+
+  // Build module.
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ASSERT_NE(nullptr, context);
+
+  // Fold the instruction to test.
+  opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+  ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+  bool succeeded = opt::FoldInstruction(inst);
+
+  // Make sure the instruction folded as expected.
+  EXPECT_TRUE(succeeded);
+  if (inst != nullptr) {
+    EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
+    inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
+    EXPECT_EQ(inst->opcode(), SpvOpConstant);
+    opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr();
+    const opt::analysis::FloatConstant* result =
+        const_mrg->GetConstantFromInst(inst)->AsFloatConstant();
+    EXPECT_NE(result, nullptr);
+    if (result != nullptr) {
+      EXPECT_EQ(result->GetDoubleValue(), tc.expected_result);
+    }
+  }
+}
+
+// clang-format off
+INSTANTIATE_TEST_CASE_P(DoubleConstantFoldingTest, DoubleInstructionFoldingTest,
+::testing::Values(
+    // Test case 0: Fold 2.0 - 1.0
+    InstructionFoldingCase<double>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpFSub %double %double_2 %double_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 1.0),
+        // Test case 1: Fold 2.0 + 1.0
+        InstructionFoldingCase<double>(
+            Header() + "%main = OpFunction %void None %void_func\n" +
+                "%main_lab = OpLabel\n" +
+                "%2 = OpFAdd %double %double_2 %double_1\n" +
+                "OpReturn\n" +
+                "OpFunctionEnd",
+            2, 3.0),
+        // Test case 2: Fold 3.0 * 2.0
+        InstructionFoldingCase<double>(
+            Header() + "%main = OpFunction %void None %void_func\n" +
+                "%main_lab = OpLabel\n" +
+                "%2 = OpFMul %double %double_3 %double_2\n" +
+                "OpReturn\n" +
+                "OpFunctionEnd",
+            2, 6.0),
+        // Test case 3: Fold 1.0 / 2.0
+        InstructionFoldingCase<double>(
+            Header() + "%main = OpFunction %void None %void_func\n" +
+                "%main_lab = OpLabel\n" +
+                "%2 = OpFDiv %double %double_1 %double_2\n" +
+                "OpReturn\n" +
+                "OpFunctionEnd",
+            2, 0.5),
+        // Test case 4: Fold 1.0 / 0.0
+        InstructionFoldingCase<double>(
+            Header() + "%main = OpFunction %void None %void_func\n" +
+                "%main_lab = OpLabel\n" +
+                "%2 = OpFDiv %double %double_1 %double_0\n" +
+                "OpReturn\n" +
+                "OpFunctionEnd",
+            2, std::numeric_limits<double>::infinity()),
+        // Test case 4: Fold -1.0 / 0.0
+        InstructionFoldingCase<double>(
+            Header() + "%main = OpFunction %void None %void_func\n" +
+                "%main_lab = OpLabel\n" +
+                "%2 = OpFDiv %double %double_n1 %double_0\n" +
+                "OpReturn\n" +
+                "OpFunctionEnd",
+            2, -std::numeric_limits<double>::infinity())
+));
+// clang-format on
 template <class ResultType>
 struct InstructionFoldingCaseWithMap {
   InstructionFoldingCaseWithMap(const std::string& tb, uint32_t id,