Add folding of OpCompositeExtract and OpConstantComposite constant instructions.
authorSteven Perron <stevenperron@google.com>
Thu, 8 Feb 2018 15:59:03 +0000 (10:59 -0500)
committerDavid Neto <dneto@google.com>
Fri, 9 Feb 2018 22:52:33 +0000 (17:52 -0500)
Create files for constant folding rules.

Add the rules for OpConstantComposite and OpCompositeExtract.

13 files changed:
Android.mk
source/opt/CMakeLists.txt
source/opt/const_folding_rules.cpp [new file with mode: 0644]
source/opt/const_folding_rules.h [new file with mode: 0644]
source/opt/constants.h
source/opt/def_use_manager.cpp
source/opt/def_use_manager.h
source/opt/fold.cpp
source/opt/folding_rules.cpp
source/opt/folding_rules.h
source/opt/ir_context.h
test/opt/fold_test.cpp
test/opt/simplification_test.cpp

index 501d8db..8f5e55b 100644 (file)
@@ -63,9 +63,10 @@ SPVTOOLS_OPT_SRC_FILES := \
                source/opt/cfg.cpp \
                source/opt/cfg_cleanup_pass.cpp \
                source/opt/ccp_pass.cpp \
+               source/opt/common_uniform_elim_pass.cpp \
                source/opt/compact_ids_pass.cpp \
                source/opt/composite.cpp \
-               source/opt/common_uniform_elim_pass.cpp \
+               source/opt/const_folding_rules.cpp \
                source/opt/constants.cpp \
                source/opt/dead_branch_elim_pass.cpp \
                source/opt/dead_insert_elim_pass.cpp \
index 906164d..b3c6ebe 100644 (file)
@@ -22,6 +22,7 @@ add_library(SPIRV-Tools-opt
   common_uniform_elim_pass.h
   compact_ids_pass.h
   composite.h
+  const_folding_rules.h
   constants.h
   dead_branch_elim_pass.h
   dead_insert_elim_pass.h
@@ -94,6 +95,7 @@ add_library(SPIRV-Tools-opt
   common_uniform_elim_pass.cpp
   compact_ids_pass.cpp
   composite.cpp
+  const_folding_rules.cpp
   constants.cpp
   dead_branch_elim_pass.cpp
   dead_insert_elim_pass.cpp
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
new file mode 100644 (file)
index 0000000..f4492db
--- /dev/null
@@ -0,0 +1,98 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "const_folding_rules.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+const uint32_t kExtractCompositeIdInIdx = 0;
+
+ConstantFoldingRule FoldExtractWithConstants() {
+  // Folds an OpcompositeExtract where input is a composite constant.
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants)
+             -> const analysis::Constant* {
+    const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
+    if (c == nullptr) {
+      return nullptr;
+    }
+
+    for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
+      uint32_t element_index = inst->GetSingleWordInOperand(i);
+      if (c->AsNullConstant()) {
+        // Return Null for the return type.
+        ir::IRContext* context = inst->context();
+        analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+        analysis::TypeManager* type_mgr = context->get_type_mgr();
+        const analysis::NullConstant null_const(
+            type_mgr->GetType(inst->type_id()));
+        const analysis::Constant* real_const =
+            const_mgr->FindConstant(&null_const);
+        if (real_const == nullptr) {
+          ir::Instruction* const_inst =
+              const_mgr->GetDefiningInstruction(&null_const);
+          real_const = const_mgr->GetConstantFromInst(const_inst);
+        }
+        return real_const;
+      }
+
+      auto cc = c->AsCompositeConstant();
+      assert(cc != nullptr);
+      auto components = cc->GetComponents();
+      c = components[element_index];
+    }
+    return c;
+  };
+}
+
+ConstantFoldingRule FoldCompositeWithConstants() {
+  // Folds an OpCompositeConstruct where all of the inputs are constants to a
+  // constant.  A new constant is created if necessary.
+  return [](ir::Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants)
+             -> const analysis::Constant* {
+    ir::IRContext* context = inst->context();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    analysis::TypeManager* type_mgr = context->get_type_mgr();
+    const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
+
+    std::vector<uint32_t> ids;
+    for (const analysis::Constant* element_const : constants) {
+      if (element_const == nullptr) {
+        return nullptr;
+      }
+      uint32_t element_id = const_mgr->FindDeclaredConstant(element_const);
+      if (element_id == 0) {
+        return nullptr;
+      }
+      ids.push_back(element_id);
+    }
+    return const_mgr->GetConstant(new_type, ids);
+  };
+}
+}  // namespace
+
+spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
+  // Add all folding rules to the list for the opcodes to which they apply.
+  // Note that the order in which rules are added to the list matters. If a rule
+  // applies to the instruction, the rest of the rules will not be attempted.
+  // Take that into consideration.
+
+  rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
+  rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
+}
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/const_folding_rules.h b/source/opt/const_folding_rules.h
new file mode 100644 (file)
index 0000000..fb7dad3
--- /dev/null
@@ -0,0 +1,84 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_CONST_FOLDING_RULES_H_
+#define LIBSPIRV_OPT_CONST_FOLDING_RULES_H_
+
+#include <vector>
+
+#include "../../external/spirv-headers/include/spirv/1.2/spirv.h"
+#include "constants.h"
+#include "def_use_manager.h"
+#include "folding_rules.h"
+#include "ir_builder.h"
+#include "ir_context.h"
+
+namespace spvtools {
+namespace opt {
+
+// Constant Folding Rules:
+//
+// The folding mechanism is built around the concept of a |ConstantFoldingRule|.
+// A constant folding rule is a function that implements a method of simplifying
+// an instruction to a constant.
+//
+// The inputs to a folding rule are:
+//     |inst| - the instruction to be simplified.
+//     |constants| - if an in-operands is an id of a constant, then the
+//                   corresponding value in |constants| contains that
+//                   constant value.  Otherwise, the corresponding entry in
+//                   |constants| is |nullptr|.
+//
+// A constant folding rule returns a pointer to an Constant if |inst| can be
+// simplified using this rule. Otherwise, it returns |nullptr|.
+//
+// See const_folding_rules.cpp for examples on how to write a constant folding
+// rule.
+//
+// Be sure to add new constant folding rules to the table of constant folding
+// rules in the constructor for ConstantFoldingRules.  The new rule should be
+// added to the list for every opcode that it applies to.  Note that earlier
+// rules in the list are given priority.  That is, if an earlier rule is able to
+// fold an instruction, the later rules will not be attempted.
+
+using ConstantFoldingRule = std::function<const analysis::Constant*(
+    ir::Instruction* inst,
+    const std::vector<const analysis::Constant*>& constants)>;
+
+class ConstantFoldingRules {
+ public:
+  ConstantFoldingRules();
+
+  // Returns true if there is at least 1 folding rule for |opcode|.
+  bool HasFoldingRule(SpvOp opcode) const { return rules_.count(opcode); }
+
+  // Returns an vector of constant folding rules for |opcode|.
+  const std::vector<ConstantFoldingRule>& GetRulesForOpcode(
+      SpvOp opcode) const {
+    auto it = rules_.find(opcode);
+    if (it != rules_.end()) {
+      return it->second;
+    }
+    return empty_vector_;
+  }
+
+ private:
+  std::unordered_map<uint32_t, std::vector<ConstantFoldingRule>> rules_;
+  std::vector<ConstantFoldingRule> empty_vector_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_CONST_FOLDING_RULES_H_
index 47cfa8a..3eb3411 100644 (file)
@@ -204,7 +204,7 @@ class CompositeConstant : public Constant {
   CompositeConstant* AsCompositeConstant() override { return this; }
   const CompositeConstant* AsCompositeConstant() const override { return this; }
 
-  // Returns a const reference of the components holded in this composite
+  // Returns a const reference of the components held in this composite
   // constant.
   virtual const std::vector<const Constant*>& GetComponents() const {
     return components_;
index 33776ce..db91f59 100644 (file)
@@ -71,6 +71,18 @@ void DefUseManager::AnalyzeInstDefUse(ir::Instruction* inst) {
   AnalyzeInstUse(inst);
 }
 
+void DefUseManager::UpdateDefUse(ir::Instruction* inst) {
+  const uint32_t def_id = inst->result_id();
+  if (def_id != 0) {
+    auto iter = id_to_def_.find(def_id);
+    if (iter != id_to_def_.end()) {
+      AnalyzeInstDef(inst);
+    } else {
+    }
+  }
+  AnalyzeInstUse(inst);
+}
+
 ir::Instruction* DefUseManager::GetDef(uint32_t id) {
   auto iter = id_to_def_.find(id);
   if (iter == id_to_def_.end()) return nullptr;
index 69047e2..2061703 100644 (file)
@@ -214,6 +214,10 @@ class DefUseManager {
     return !(lhs == rhs);
   }
 
+  // If |inst| has not already been analysed, then analyses its defintion and
+  // uses.
+  void UpdateDefUse(ir::Instruction* inst);
+
  private:
   using InstToUsedIdsMap =
       std::unordered_map<const ir::Instruction*, std::vector<uint32_t>>;
index 04c0446..f3f51cd 100644 (file)
 
 #include "fold.h"
 
+#include <cassert>
+#include <cstdint>
+#include <vector>
+
+#include "const_folding_rules.h"
 #include "def_use_manager.h"
 #include "folding_rules.h"
 #include "ir_builder.h"
 #include "ir_context.h"
 
-#include <cassert>
-#include <cstdint>
-#include <vector>
-
 namespace spvtools {
 namespace opt {
 
@@ -40,6 +41,11 @@ namespace {
 #define UINT32_MAX 0xffffffff /* 4294967295U */
 #endif
 
+const ConstantFoldingRules& GetConstantFoldingRules() {
+  static ConstantFoldingRules* rules = new ConstantFoldingRules();
+  return *rules;
+}
+
 // Returns the single-word result from performing the given unary operation on
 // the operand value which is passed in as a 32-bit word.
 uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
@@ -603,10 +609,6 @@ bool IsFoldableConstant(const analysis::Constant* cst) {
 
 ir::Instruction* FoldInstructionToConstant(
     ir::Instruction* inst, std::function<uint32_t(uint32_t)> id_map) {
-  if (!inst->IsFoldable()) {
-    return nullptr;
-  }
-
   ir::IRContext* context = inst->context();
   analysis::ConstantManager* const_mgr = context->get_constant_mgr();
 
@@ -617,7 +619,7 @@ ir::Instruction* FoldInstructionToConstant(
                      &id_map](uint32_t* op_id) {
     uint32_t id = id_map(*op_id);
     const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
-    if (!const_op || !IsFoldableConstant(const_op)) {
+    if (!const_op) {
       constants.push_back(nullptr);
       missing_constants = true;
       return;
@@ -625,15 +627,30 @@ ir::Instruction* FoldInstructionToConstant(
     constants.push_back(const_op);
   });
 
+  if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
+    const analysis::Constant* folded_const = nullptr;
+    for (auto rule :
+         GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) {
+      folded_const = rule(inst, constants);
+      if (folded_const != nullptr) {
+        ir::Instruction* const_inst =
+            const_mgr->GetDefiningInstruction(folded_const);
+        // May be a new instruction that needs to be analysed.
+        context->UpdateDefUse(const_inst);
+        return const_inst;
+      }
+    }
+  }
+
   uint32_t result_val = 0;
   bool successful = false;
   // If all parameters are constant, fold the instruction to a constant.
-  if (!missing_constants) {
+  if (!missing_constants && inst->IsFoldable()) {
     result_val = FoldScalars(inst->opcode(), constants);
     successful = true;
   }
 
-  if (!successful) {
+  if (!successful && inst->IsFoldable()) {
     successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
   }
 
index ed33dfd..4762e75 100644 (file)
@@ -69,8 +69,8 @@ FoldingRule CompositeConstructFeedingExtract() {
 
       // Add the remaining indices for extraction.
       for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
-        operands.push_back(
-            {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(i)}});
+        operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
+                            {inst->GetSingleWordInOperand(i)}});
       }
 
     } else {
@@ -302,14 +302,14 @@ spvtools::opt::FoldingRules::FoldingRules() {
   // applies to the instruction, the rest of the rules will not be attempted.
   // Take that into consideration.
 
-  rules[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());
+  rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());
 
-  rules[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
-  rules[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
+  rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
+  rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
 
-  rules[SpvOpIMul].push_back(IntMultipleBy1());
+  rules_[SpvOpIMul].push_back(IntMultipleBy1());
 
-  rules[SpvOpPhi].push_back(RedundantPhi());
+  rules_[SpvOpPhi].push_back(RedundantPhi());
 }
 }  // namespace opt
 }  // namespace spvtools
index 8ac940a..78277e8 100644 (file)
@@ -63,15 +63,15 @@ class FoldingRules {
   FoldingRules();
 
   const std::vector<FoldingRule>& GetRulesForOpcode(SpvOp opcode) {
-    auto it = rules.find(opcode);
-    if (it != rules.end()) {
+    auto it = rules_.find(opcode);
+    if (it != rules_.end()) {
       return it->second;
     }
     return empty_vector_;
   }
 
  private:
-  std::unordered_map<uint32_t, std::vector<FoldingRule>> rules;
+  std::unordered_map<uint32_t, std::vector<FoldingRule>> rules_;
   std::vector<FoldingRule> empty_vector_;
 };
 
index 03846c2..d209ed3 100644 (file)
@@ -399,6 +399,10 @@ class IRContext {
   // Returns the grammar for this context.
   const libspirv::AssemblyGrammar& grammar() const { return grammar_; }
 
+  // If |inst| has not yet been analysed by the def-use manager, then analyse
+  // its definitions and uses.
+  inline void UpdateDefUse(Instruction* inst);
+
  private:
   // Builds the def-use manager from scratch, even if it was already valid.
   void BuildDefUseManager() {
@@ -723,6 +727,12 @@ void IRContext::AnalyzeDefUse(Instruction* inst) {
   }
 }
 
+void IRContext::UpdateDefUse(Instruction* inst) {
+  if (AreAnalysesValid(kAnalysisDefUse)) {
+    get_def_use_mgr()->UpdateDefUse(inst);
+  }
+}
+
 }  // namespace ir
 }  // namespace spvtools
 #endif  // SPIRV_TOOLS_IR_CONTEXT_H
index 7d2567d..ac40f90 100644 (file)
@@ -78,6 +78,8 @@ TEST_P(IntegerInstructionFoldingTest, Case) {
 // Returns a common SPIR-V header for all of the test that follow.
 #define INT_0_ID 100
 #define TRUE_ID 101
+#define VEC2_0_ID 102
+#define INT_7_ID 103
 const std::string& Header() {
   static const std::string header = R"(OpCapability Shader
 %1 = OpExtInstImport "GLSL.std.450"
@@ -89,8 +91,8 @@ OpName %main "main"
 %void = OpTypeVoid
 %void_func = OpTypeFunction %void
 %bool = OpTypeBool
-%true = OpConstantTrue %bool
 %101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps.
+%true = OpConstantTrue %bool
 %false = OpConstantFalse %bool
 %short = OpTypeInt 16 1
 %int = OpTypeInt 32 1
@@ -104,9 +106,10 @@ OpName %main "main"
 %_ptr_bool = OpTypePointer Function %bool
 %short_0 = OpConstant %short 0
 %short_3 = OpConstant %short 3
+%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
+%103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps.
 %int_0 = OpConstant %int 0
 %int_1 = OpConstant %int 1
-%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
 %int_3 = OpConstant %int 3
 %int_min = OpConstant %int -2147483648
 %int_max = OpConstant %int 2147483647
@@ -116,8 +119,11 @@ OpName %main "main"
 %uint_3 = OpConstant %uint 3
 %uint_32 = OpConstant %uint 32
 %uint_max = OpConstant %uint -1
+%v2int_undef = OpUndef %v2int
 %struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
+%102 = OpConstantComposite %v2int %103 %103
 %v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
+%struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
 )";
 
   return header;
@@ -1227,7 +1233,23 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe
             "%5 = OpCompositeExtract %int %4 0 1\n" +
             "OpReturn\n" +
             "OpFunctionEnd",
-        5, 2)
+        5, 2),
+    // Test case 7: fold constant extract.
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpCompositeExtract %int %102 1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, INT_7_ID),
+    // Test case 8: constant struct has OpUndef
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpCompositeExtract %int %struct_undef_0_0 0 1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0)
 ));
 
 INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest,
@@ -1282,7 +1304,15 @@ INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFolding
             "%7 = OpCompositeConstruct %v4int %3 %4 %5\n" +
             "OpReturn\n" +
             "OpFunctionEnd",
-        7, 0)
+        7, 0),
+    // Test case 4: Fold construct with constants to constant.
+    InstructionFoldingCase<uint32_t>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpCompositeConstruct %v2int %103 %103\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, VEC2_0_ID)
 ));
 
 INSTANTIATE_TEST_CASE_P(PhiFoldingTest, GeneralInstructionFoldingTest,
index 76531bc..6e1ff23 100644 (file)
@@ -92,8 +92,6 @@ TEST_F(SimplificationTest, AcrossBasicBlocks) {
         %int = OpTypeInt 32 1
       %v4int = OpTypeVector %int 4
       %int_0 = OpConstant %int 0
-; CHECK: [[constant:%[a-zA-Z_\d]+]] = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
-         %13 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
 %_ptr_Input_v4int = OpTypePointer Input %v4int
           %i = OpVariable %_ptr_Input_v4int Input
        %uint = OpTypeInt 32 0
@@ -115,14 +113,14 @@ TEST_F(SimplificationTest, AcrossBasicBlocks) {
                OpSelectionMerge %30 None
                OpBranchConditional %29 %31 %32
          %31 = OpLabel
-         %43 = OpCopyObject %v4int %13
+         %43 = OpCopyObject %v4int %25
                OpBranch %30
          %32 = OpLabel
-         %45 = OpCopyObject %v4int %13
+         %45 = OpCopyObject %v4int %25
                OpBranch %30
          %30 = OpLabel
          %50 = OpPhi %v4int %43 %31 %45 %32
-; CHECK: [[extract1:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[constant]] 0
+; CHECK: [[extract1:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[load]] 0
          %47 = OpCompositeExtract %int %50 0
 ; CHECK: [[extract2:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[load]] 1
          %49 = OpCompositeExtract %int %41 1
@@ -170,9 +168,11 @@ TEST_F(SimplificationTest, ThroughLoops) {
          %68 = OpUndef %v4int
        %main = OpFunction %void None %8
          %23 = OpLabel
+; CHECK: [[load:%[a-zA-Z_\d]+]] = OpLoad %v4int %i
+       %load = OpLoad %v4int %i
                OpBranch %24
          %24 = OpLabel
-         %67 = OpPhi %v4int %13 %23 %64 %26
+         %67 = OpPhi %v4int %load %23 %64 %26
 ; CHECK: OpLoopMerge [[merge_lab:%[a-zA-Z_\d]+]]
                OpLoopMerge %25 %26 None
                OpBranch %27
@@ -191,7 +191,7 @@ TEST_F(SimplificationTest, ThroughLoops) {
                OpBranch %24
          %25 = OpLabel
 ; CHECK: [[merge_lab]] = OpLabel
-; CHECK: [[extract:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[constant]] 0
+; CHECK: [[extract:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[load]] 0
          %66 = OpCompositeExtract %int %67 0
 ; CHECK-NEXT: OpStore %o [[extract]]
                OpStore %o %66