Avoid redundant predicate checking in type matching.
authorJacques Pienaar <jpienaar@google.com>
Fri, 11 Jan 2019 15:41:12 +0000 (07:41 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 22:09:25 +0000 (15:09 -0700)
Expand type matcher template generator to consider a set of predicates that are known to
hold. This avoids inserting redundant checking for trivially true predicates
(for example predicate that hold according to the op definition). This only targets predicates that trivially holds and does not attempt any logic equivalence proof.

PiperOrigin-RevId: 228880468

mlir/include/mlir/TableGen/Predicate.h
mlir/lib/TableGen/Operator.cpp
mlir/lib/TableGen/Predicate.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp

index a89fbf15e483dbf70f6d60b7690aacb8931ca246..a667d92788c32eadca63052564620d550fed22d3 100644 (file)
@@ -60,7 +60,7 @@ public:
 
   // Returns the template string to construct the matcher corresponding to this
   // predicate CNF. The string uses '{0}' to represent the type.
-  std::string createTypeMatcherTemplate() const;
+  std::string createTypeMatcherTemplate(PredCNF predsKnownToHold) const;
 
 private:
   // The TableGen definition of this predicate CNF. nullptr means an empty
index 595cf8a59b756ea39fd2fb653a423703af75b642..61b6d1745b5dc89dccee45cc67e22e96504f59e9 100644 (file)
@@ -163,5 +163,6 @@ bool tblgen::Operator::Operand::hasMatcher() const {
 }
 
 std::string tblgen::Operator::Operand::createTypeMatcherTemplate() const {
-  return tblgen::Type(defInit).getPredicate().createTypeMatcherTemplate();
+  return tblgen::Type(defInit).getPredicate().createTypeMatcherTemplate(
+      PredCNF());
 }
index 88d08f565c9da80303d2957c7f33ca5e025839b2..b297fff80f5e62f857bc7a23264e96eb2d99e0f1 100644 (file)
@@ -20,6 +20,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/Predicate.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Error.h"
@@ -42,15 +43,29 @@ const llvm::ListInit *tblgen::PredCNF::getConditions() const {
   return def->getValueAsListInit("conditions");
 }
 
-std::string tblgen::PredCNF::createTypeMatcherTemplate() const {
+std::string
+tblgen::PredCNF::createTypeMatcherTemplate(PredCNF predsKnownToHold) const {
   const auto *conjunctiveList = getConditions();
   if (!conjunctiveList)
     return "true";
 
+  // Create a set of all the disjunctive conditions that hold. This is taking
+  // advantage of uniquieing of lists to discard based on the pointer
+  // below. This is not perfect but this will also be moved to FSM matching in
+  // future and gets rid of trivial redundant checking.
+  llvm::SmallSetVector<const llvm::Init *, 4> existingConditions;
+  auto existingList = predsKnownToHold.getConditions();
+  if (existingList) {
+    for (auto disjunctiveInit : *existingList)
+      existingConditions.insert(disjunctiveInit);
+  }
+
   std::string outString;
   llvm::raw_string_ostream ss(outString);
   bool firstDisjunctive = true;
   for (auto disjunctiveInit : *conjunctiveList) {
+    if (existingConditions.count(disjunctiveInit) != 0)
+      continue;
     ss << (firstDisjunctive ? "(" : " && (");
     firstDisjunctive = false;
     bool firstConjunctive = true;
@@ -63,6 +78,8 @@ std::string tblgen::PredCNF::createTypeMatcherTemplate() const {
     }
     ss << ")";
   }
+  if (firstDisjunctive)
+    return "true";
   ss.flush();
   return outString;
 }
index b4e38954ec7965fbcd1e247dd3fa8fefb65fd33c..615b08fa1571e5dc4e0a48d542aebaffb66f89ef 100644 (file)
@@ -175,10 +175,10 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
                           "type argument required for operand");
 
         auto pred = tblgen::Type(defInit).getPredicate();
-
+        auto opPred = tblgen::Type(operand->defInit).getPredicate();
         os.indent(indent)
             << "if (!("
-            << formatv(pred.createTypeMatcherTemplate().c_str(),
+            << formatv(pred.createTypeMatcherTemplate(opPred).c_str(),
                        formatv("op{0}->getOperand({1})->getType()", depth, i))
             << ")) return matchFailure();\n";
       }