This CL allows user to specify the same name for the operands in the source pattern which implicitly enforces equality on operands with the same name.
E.g., Pat<(OpA $a, $b, $a) ... > would create a matching rule for checking equality for the first and the last operands. Equality of the operands is enforced at any depth, e.g., OpA ($a, $b, OpB($a, $c, OpC ($a))).
Example usage: Pat<(Reshape $arg0, (Shape $arg0)), (replaceWithValue $arg0)>
Note, this feature only covers operands but not attributes.
Current use cases are based on the operand equality and explicitly add the constraint into the pattern. Attribute equality will be worked out on the different CL.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D89254
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
+#include <unordered_map>
+
namespace llvm {
class DagInit;
class Init;
// value bound by this symbol.
std::string getVarDecl(StringRef name) const;
+ // Returns a variable name for the symbol named as `name`.
+ std::string getVarName(StringRef name) const;
+
private:
// Allow SymbolInfoMap to access private methods.
friend class SymbolInfoMap;
Kind kind; // The kind of the bound entity
// The argument index (for `Attr` and `Operand` only)
Optional<int> argIndex;
+ // Alternative name for the symbol. It is used in case the name
+ // is not unique. Applicable for `Operand` only.
+ Optional<std::string> alternativeName;
};
- using BaseT = llvm::StringMap<SymbolInfo>;
+ using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
// Iterators for accessing all symbols.
using iterator = BaseT::iterator;
const_iterator end() const { return symbolInfoMap.end(); }
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
- // Returns false if `symbol` is already bound.
+ // Returns false if `symbol` is already bound and symbols are not operands.
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
// Binds the given `symbol` to the results the given `op`. Returns false if
// Returns an iterator to the information of the given symbol named as `key`.
const_iterator find(StringRef key) const;
+ // Returns an iterator to the information of the given symbol named as `key`,
+ // with index `argIndex` for operator `op`.
+ const_iterator findBoundSymbol(StringRef key, const Operator &op,
+ int argIndex) const;
+
+ // Returns the bounds of a range that includes all the elements which
+ // bind to the `key`.
+ std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
+
+ // Returns number of times symbol named as `key` was used.
+ int count(StringRef key) const;
+
// Returns the number of static values of the given `symbol` corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol only
// represents one static value, but symbols bound to op results can represent
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;
+ // Assign alternative unique names to Operands that have equal names.
+ void assignUniqueAlternativeNames();
+
// Splits the given `symbol` into a value pack name and an index. Returns the
// value pack name and writes the index to `index` on success. Returns
// `symbol` itself if it does not contain an index.
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
private:
- llvm::StringMap<SymbolInfo> symbolInfoMap;
+ BaseT symbolInfoMap;
// Pattern instantiation location. This is intended to be used as parameter
// to PrintFatalError() to report errors.
llvm_unreachable("unknown kind");
}
+std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
+ return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
+}
+
std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
switch (kind) {
case Kind::Operand: {
// Use operand range for captured operands (to support potential variadic
// operands).
- return std::string(formatv(
- "::mlir::Operation::operand_range {0}(op0->getOperands());\n", name));
+ return std::string(
+ formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
+ getVarName(name)));
}
case Kind::Value: {
return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
? SymbolInfo::getAttr(&op, argIndex)
: SymbolInfo::getOperand(&op, argIndex);
- return symbolInfoMap.insert({symbol, symInfo}).second;
+ std::string key = symbol.str();
+ if (symbolInfoMap.count(key)) {
+ // Only non unique name for the operand is supported.
+ if (symInfo.kind != SymbolInfo::Kind::Operand) {
+ return false;
+ }
+
+ // Cannot add new operand if there is already non operand with the same
+ // name.
+ if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
+ return false;
+ }
+ }
+
+ symbolInfoMap.emplace(key, symInfo);
+ return true;
}
bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
- StringRef name = getValuePackName(symbol);
- return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
+ std::string name = getValuePackName(symbol).str();
+ auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
+
+ return symbolInfoMap.count(inserted->first) == 1;
}
bool SymbolInfoMap::bindValue(StringRef symbol) {
- return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
+ auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
+ return symbolInfoMap.count(inserted->first) == 1;
}
bool SymbolInfoMap::contains(StringRef symbol) const {
}
SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
- StringRef name = getValuePackName(key);
+ std::string name = getValuePackName(key).str();
+
return symbolInfoMap.find(name);
}
+SymbolInfoMap::const_iterator
+SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
+ int argIndex) const {
+ std::string name = getValuePackName(key).str();
+ auto range = symbolInfoMap.equal_range(name);
+
+ for (auto it = range.first; it != range.second; ++it) {
+ if (it->second.op == &op && it->second.argIndex == argIndex) {
+ return it;
+ }
+ }
+
+ return symbolInfoMap.end();
+}
+
+std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
+SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
+ std::string name = getValuePackName(key).str();
+
+ return symbolInfoMap.equal_range(name);
+}
+
+int SymbolInfoMap::count(StringRef key) const {
+ std::string name = getValuePackName(key).str();
+ return symbolInfoMap.count(name);
+}
+
int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
StringRef name = getValuePackName(symbol);
if (name != symbol) {
return 1;
}
// Otherwise, find how many it represents by querying the symbol's info.
- return find(name)->getValue().getStaticValueCount();
+ return find(name)->second.getStaticValueCount();
}
std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
int index = -1;
StringRef name = getValuePackName(symbol, &index);
- auto it = symbolInfoMap.find(name);
+ auto it = symbolInfoMap.find(name.str());
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}
- return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
+ return it->second.getValueAndRangeUse(name, index, fmt, separator);
}
std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
int index = -1;
StringRef name = getValuePackName(symbol, &index);
- auto it = symbolInfoMap.find(name);
+ auto it = symbolInfoMap.find(name.str());
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}
- return it->getValue().getAllRangeUse(name, index, fmt, separator);
+ return it->second.getAllRangeUse(name, index, fmt, separator);
+}
+
+void SymbolInfoMap::assignUniqueAlternativeNames() {
+ llvm::StringSet<> usedNames;
+
+ for (auto symbolInfoIt = symbolInfoMap.begin();
+ symbolInfoIt != symbolInfoMap.end();) {
+ auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
+ auto startRange = range.first;
+ auto endRange = range.second;
+
+ auto operandName = symbolInfoIt->first;
+ int startSearchIndex = 0;
+ for (++startRange; startRange != endRange; ++startRange) {
+ // Current operand name is not unique, find a unique one
+ // and set the alternative name.
+ for (int i = startSearchIndex;; ++i) {
+ std::string alternativeName = operandName + std::to_string(i);
+ if (!usedNames.contains(alternativeName) &&
+ symbolInfoMap.count(alternativeName) == 0) {
+ usedNames.insert(alternativeName);
+ startRange->second.alternativeName = alternativeName;
+ startSearchIndex = i + 1;
+
+ break;
+ }
+ }
+ }
+
+ symbolInfoIt = endRange;
+ }
}
//===----------------------------------------------------------------------===//
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
+
+ LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
+ infoMap.assignUniqueAlternativeNames();
+ LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
}
void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
let results = (outs I32);
}
+def OpN : TEST_Op<"op_n"> {
+ let arguments = (ins I32, I32);
+ let results = (outs I32);
+}
+
+def OpO : TEST_Op<"op_o"> {
+ let arguments = (ins I32);
+ let results = (outs I32);
+}
+
+def OpP : TEST_Op<"op_p"> {
+ let arguments = (ins I32, I32, I32, I32, I32, I32);
+ let results = (outs I32);
+}
+
+// Test same operand name enforces equality condition check.
+def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
+
+// Test when equality is enforced at different depth.
+def TestNestedOpEqualArgsPattern :
+ Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
+
+// Test multiple equal arguments check enforced.
+def TestMultipleEqualArgsPattern :
+ Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
+
// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
return
}
+// CHECK-LABEL: verifyEqualArgs
+func @verifyEqualArgs(%arg0: i32, %arg1: i32) {
+ // def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
+
+ // CHECK: "test.op_o"(%arg0) : (i32) -> i32
+ "test.op_n"(%arg0, %arg0) : (i32, i32) -> (i32)
+
+ // CHECK: "test.op_n"(%arg0, %arg1) : (i32, i32) -> i32
+ "test.op_n"(%arg0, %arg1) : (i32, i32) -> (i32)
+
+ return
+}
+
+// CHECK-LABEL: verifyNestedOpEqualArgs
+func @verifyNestedOpEqualArgs(
+ %arg0: i32, %arg1: i32, %arg2 : i32, %arg3 : i32, %arg4 : i32, %arg5 : i32) {
+ // def TestNestedOpEqualArgsPattern :
+ // Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
+
+ // CHECK: %arg1
+ %0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
+ : (i32, i32, i32, i32, i32, i32) -> (i32)
+ %1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
+
+ // CHECK: test.op_p
+ // CHECK: test.op_n
+ %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
+ : (i32, i32, i32, i32, i32, i32) -> (i32)
+ %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
+
+ return
+}
+
+// CHECK-LABEL: verifyMultipleEqualArgs
+func @verifyMultipleEqualArgs(
+ %arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {
+ // def TestMultipleEqualArgsPattern :
+ // Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
+
+ // CHECK: "test.op_n"(%arg2, %arg1) : (i32, i32) -> i32
+ "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg1, %arg2) :
+ (i32, i32, i32, i32 , i32, i32) -> i32
+
+ // CHECK: test.op_p
+ "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg0, %arg2) :
+ (i32, i32, i32, i32 , i32, i32) -> i32
+
+ // CHECK: test.op_p
+ "test.op_p"(%arg0, %arg1, %arg1, %arg0, %arg1, %arg2) :
+ (i32, i32, i32, i32 , i32, i32) -> i32
+
+ // CHECK: test.op_p
+ "test.op_p"(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4) :
+ (i32, i32, i32, i32 , i32, i32) -> i32
+
+ return
+}
+
//===----------------------------------------------------------------------===//
// Test Symbol Binding
//===----------------------------------------------------------------------===//
void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt);
+ // Emits C++ for checking a match with a corresponding match failure
+ // diagnostics.
+ void emitMatchCheck(int depth, const std::string &matchStr,
+ const std::string &failureStr);
+
//===--------------------------------------------------------------------===//
// Rewrite utilities
//===--------------------------------------------------------------------===//
op.arg_begin(), op.arg_begin() + argIndex,
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
- os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
- argIndex - numPrevAttrs);
+ auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
+ os << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
+ res->second.getVarName(name), depth, argIndex - numPrevAttrs);
}
}
void PatternEmitter::emitMatchCheck(
int depth, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt) {
- os << "if (!(" << matchFmt.str() << "))";
+ emitMatchCheck(depth, matchFmt.str(), failureFmt.str());
+}
+
+void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr,
+ const std::string &failureStr) {
+ os << "if (!(" << matchStr << "))";
os.scope("{\n", "\n}\n").os
<< "return rewriter.notifyMatchFailure(op" << depth
- << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureFmt.str()
+ << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr
<< ";\n});";
}
constraint.getDescription()));
}
}
+
+ // Some of the operands could be bound to the same symbol name, we need
+ // to enforce equality constraint on those.
+ // TODO: we should be able to emit equality checks early
+ // and short circuit unnecessary work if vars are not equal.
+ for (auto symbolInfoIt = symbolInfoMap.begin();
+ symbolInfoIt != symbolInfoMap.end();) {
+ auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
+ auto startRange = range.first;
+ auto endRange = range.second;
+
+ auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
+ for (++startRange; startRange != endRange; ++startRange) {
+ auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
+ emitMatchCheck(
+ depth,
+ formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
+ formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
+ secondOperand));
+ }
+
+ symbolInfoIt = endRange;
+ }
+
LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}
// Create local variables for storing the arguments and results bound
// to symbols.
for (const auto &symbolInfoPair : symbolInfoMap) {
- StringRef symbol = symbolInfoPair.getKey();
- auto &info = symbolInfoPair.getValue();
+ const auto &symbol = symbolInfoPair.first;
+ const auto &info = symbolInfoPair.second;
+
os << info.getVarDecl(symbol);
}
// TODO: capture ops with consistent numbering so that it can be
os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
range);
} else {
- os << formatv("tblgen_values.push_back(", varName);
+ os << formatv("tblgen_values.push_back(");
if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(
childNodeNames.lookup(argIndex));