Fixes #1357. Support null constants better in folding
authorAlan Baker <alanbaker@google.com>
Wed, 28 Feb 2018 20:23:19 +0000 (15:23 -0500)
committerSteven Perron <stevenperron@google.com>
Thu, 1 Mar 2018 04:12:27 +0000 (23:12 -0500)
* getFloatConstantKind() now handles OpConstantNull
* PerformOperation() now handles OpConstantNull for vectors
* Fixed some instances where we would attempt to merge a division by 0
* added tests

source/opt/constants.h
source/opt/folding_rules.cpp
test/opt/fold_test.cpp

index cd3134b..999dc52 100644 (file)
@@ -126,6 +126,18 @@ class ScalarConstant : public Constant {
   // Returns a const reference of the value of this constant in 32-bit words.
   virtual const std::vector<uint32_t>& words() const { return words_; }
 
+  // Returns true if the value is zero.
+  bool IsZero() const {
+    bool is_zero = true;
+    for (uint32_t v : words()) {
+      if (v != 0) {
+        is_zero = false;
+        break;
+      }
+    }
+    return is_zero;
+  }
+
  protected:
   ScalarConstant(const Type* ty, const std::vector<uint32_t>& w)
       : Constant(ty), words_(w) {}
@@ -175,17 +187,6 @@ class IntConstant : public ScalarConstant {
            static_cast<uint64_t>(words()[0]);
   }
 
-  bool IsZero() const {
-    bool is_zero = true;
-    for (uint32_t v : words()) {
-      if (v != 0) {
-        is_zero = false;
-        break;
-      }
-    }
-    return is_zero;
-  }
-
   // Make a copy of this IntConstant instance.
   std::unique_ptr<IntConstant> CopyIntConstant() const {
     return MakeUnique<IntConstant>(type_->AsInteger(), words_);
index f94ba7b..7e4dddb 100644 (file)
@@ -218,9 +218,12 @@ FoldingRule ReciprocalFDiv() {
         const analysis::Constant* negated_const =
             const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
         id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
-      } else {
+      } else if (constants[1]->AsFloatConstant()) {
         id = Reciprocal(const_mgr, constants[1]);
         if (id == 0) return false;
+      } else {
+        // Don't fold a null constant.
+        return false;
       }
       inst->SetOpcode(SpvOpFMul);
       inst->SetInOperands(
@@ -384,6 +387,22 @@ FoldingRule MergeNegateAddSubArithmetic() {
   };
 }
 
+// Returns true if |c| has a zero element.
+bool HasZero(const analysis::Constant* c) {
+  if (c->AsNullConstant()) {
+    return true;
+  }
+  if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
+    for (auto& comp : vec_const->GetComponents())
+      if (HasZero(comp)) return true;
+  } else {
+    assert(c->AsScalarConstant());
+    return c->AsScalarConstant()->IsZero();
+  }
+
+  return false;
+}
+
 // Performs |input1| |opcode| |input2| and returns the merged constant result
 // id. Returns 0 if the result is not a valid value. The input types must be
 // Float.
@@ -415,6 +434,7 @@ uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
       FOLD_OP(*);
       break;
     case SpvOpFDiv:
+      if (HasZero(input2)) return 0;
       FOLD_OP(/);
       break;
     case SpvOpFAdd:
@@ -498,10 +518,25 @@ uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
     const analysis::Type* ele_type = vector_type->element_type();
     for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
       uint32_t id = 0;
-      const analysis::Constant* input1_comp =
-          input1->AsVectorConstant()->GetComponents()[i];
-      const analysis::Constant* input2_comp =
-          input2->AsVectorConstant()->GetComponents()[i];
+
+      const analysis::Constant* input1_comp = nullptr;
+      if (const analysis::VectorConstant* input1_vector =
+              input1->AsVectorConstant()) {
+        input1_comp = input1_vector->GetComponents()[i];
+      } else {
+        assert(input1->AsNullConstant());
+        input1_comp = const_mgr->GetConstant(ele_type, {});
+      }
+
+      const analysis::Constant* input2_comp = nullptr;
+      if (const analysis::VectorConstant* input2_vector =
+              input2->AsVectorConstant()) {
+        input2_comp = input2_vector->GetComponents()[i];
+      } else {
+        assert(input2->AsNullConstant());
+        input2_comp = const_mgr->GetConstant(ele_type, {});
+      }
+
       if (ele_type->AsFloat()) {
         id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
                                            input2_comp);
@@ -603,7 +638,7 @@ FoldingRule MergeMulDivArithmetic() {
       std::vector<const analysis::Constant*> other_constants =
           const_mgr->GetOperandConstants(other_inst);
       const analysis::Constant* const_input2 = ConstInput(other_constants);
-      if (!const_input2) return false;
+      if (!const_input2 || HasZero(const_input2)) return false;
 
       bool other_first_is_variable = other_constants[0] == nullptr;
       // If the variable value is the second operand of the divide, multiply
@@ -695,7 +730,7 @@ FoldingRule MergeDivDivArithmetic() {
     if (width != 32 && width != 64) return false;
 
     const analysis::Constant* const_input1 = ConstInput(constants);
-    if (!const_input1) return false;
+    if (!const_input1 || HasZero(const_input1)) return false;
     ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
 
@@ -704,7 +739,7 @@ FoldingRule MergeDivDivArithmetic() {
       std::vector<const analysis::Constant*> other_constants =
           const_mgr->GetOperandConstants(other_inst);
       const analysis::Constant* const_input2 = ConstInput(other_constants);
-      if (!const_input2) return false;
+      if (!const_input2 || HasZero(const_input2)) return false;
 
       bool other_first_is_variable = other_constants[0] == nullptr;
 
@@ -765,7 +800,7 @@ FoldingRule MergeDivMulArithmetic() {
     if (width != 32 && width != 64) return false;
 
     const analysis::Constant* const_input1 = ConstInput(constants);
-    if (!const_input1) return false;
+    if (!const_input1 || HasZero(const_input1)) return false;
     ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
 
@@ -1543,7 +1578,12 @@ FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
     return FloatConstantKind::Unknown;
   }
 
-  if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) {
+  assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
+
+  if (constant->AsNullConstant()) {
+    return FloatConstantKind::Zero;
+  } else if (const analysis::VectorConstant* vc =
+                 constant->AsVectorConstant()) {
     const std::vector<const analysis::Constant*>& components =
         vc->GetComponents();
     assert(!components.empty());
index 4e418b9..345cfea 100644 (file)
@@ -198,6 +198,7 @@ OpName %main "main"
 %v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
 %v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
 %v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
+%v2float_null = OpConstantNull %v2float
 %double_n1 = OpConstant %double -1
 %105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
 %double_0 = OpConstant %double 0
@@ -2526,7 +2527,37 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
             "%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" +
             "OpReturn\n" +
             "OpFunctionEnd",
-        2, 4)
+        2, 4),
+    // Test case 15: Fold vector fadd with null
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%a = OpVariable %_ptr_v2float Function\n" +
+            "%2 = OpLoad %v2float %a\n" +
+            "%3 = OpFAdd %v2float %2 %v2float_null\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        3, 2),
+    // Test case 16: Fold vector fadd with null
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%a = OpVariable %_ptr_v2float Function\n" +
+            "%2 = OpLoad %v2float %a\n" +
+            "%3 = OpFAdd %v2float %v2float_null %2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        3, 2),
+    // Test case 15: Fold vector fsub with null
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%a = OpVariable %_ptr_v2float Function\n" +
+            "%2 = OpLoad %v2float %a\n" +
+            "%3 = OpFSub %v2float %2 %v2float_null\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        3, 2)
 ));
 
 INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest,
@@ -3317,7 +3348,18 @@ INSTANTIATE_TEST_CASE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,
       "%3 = OpFDiv %double %2 %double_2\n" +
       "OpReturn\n" +
       "OpFunctionEnd\n",
-    3, true)
+    3, true),
+  // Test case 4: don't fold x / 0.
+  InstructionFoldingCase<bool>(
+    Header() +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_v2float Function\n" +
+      "%2 = OpLoad %v2float %var\n" +
+      "%3 = OpFDiv %v2float %2 %v2float_null\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    3, false)
 ));
 
 INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest,
@@ -3812,7 +3854,20 @@ INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest,
       "%4 = OpSDiv %int %int_2 %3\n" +
       "OpReturn\n" +
       "OpFunctionEnd\n",
-    4, true)
+    4, true),
+  // Test case 13: Don't merge
+  // (x / {null}) / {null}
+  InstructionFoldingCase<bool>(
+    Header() +
+      "%main = OpFunction %void None %void_func\n" +
+      "%main_lab = OpLabel\n" +
+      "%var = OpVariable %_ptr_v2float Function\n" +
+      "%2 = OpLoad %float %var\n" +
+      "%3 = OpFDiv %float %2 %v2float_null\n" +
+      "%4 = OpFDiv %float %3 %v2float_null\n" +
+      "OpReturn\n" +
+      "OpFunctionEnd\n",
+    4, false)
 ));
 
 INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest,