// Returns the template string to construct the matcher corresponding to this
// predicate CNF. The string uses '{0}' to represent the type.
- std::string createTypeMatcherTemplate() const;
+ std::string createTypeMatcherTemplate(PredCNF predsKnownToHold) const;
private:
// The TableGen definition of this predicate CNF. nullptr means an empty
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Predicate.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
return def->getValueAsListInit("conditions");
}
-std::string tblgen::PredCNF::createTypeMatcherTemplate() const {
+std::string
+tblgen::PredCNF::createTypeMatcherTemplate(PredCNF predsKnownToHold) const {
const auto *conjunctiveList = getConditions();
if (!conjunctiveList)
return "true";
+ // Create a set of all the disjunctive conditions that hold. This is taking
+ // advantage of uniquieing of lists to discard based on the pointer
+ // below. This is not perfect but this will also be moved to FSM matching in
+ // future and gets rid of trivial redundant checking.
+ llvm::SmallSetVector<const llvm::Init *, 4> existingConditions;
+ auto existingList = predsKnownToHold.getConditions();
+ if (existingList) {
+ for (auto disjunctiveInit : *existingList)
+ existingConditions.insert(disjunctiveInit);
+ }
+
std::string outString;
llvm::raw_string_ostream ss(outString);
bool firstDisjunctive = true;
for (auto disjunctiveInit : *conjunctiveList) {
+ if (existingConditions.count(disjunctiveInit) != 0)
+ continue;
ss << (firstDisjunctive ? "(" : " && (");
firstDisjunctive = false;
bool firstConjunctive = true;
}
ss << ")";
}
+ if (firstDisjunctive)
+ return "true";
ss.flush();
return outString;
}
"type argument required for operand");
auto pred = tblgen::Type(defInit).getPredicate();
-
+ auto opPred = tblgen::Type(operand->defInit).getPredicate();
os.indent(indent)
<< "if (!("
- << formatv(pred.createTypeMatcherTemplate().c_str(),
+ << formatv(pred.createTypeMatcherTemplate(opPred).c_str(),
formatv("op{0}->getOperand({1})->getType()", depth, i))
<< ")) return matchFailure();\n";
}