};
} // end namespace llvm
-// Returns the bound symbol for the given op argument or op named `symbol`.
-//
-// Arguments and ops bound in the source pattern are grouped into a
-// transient `PatternState` struct. This struct can be accessed in both
-// `match()` and `rewrite()` via the local variable named as `s`.
-static Twine getBoundSymbol(const StringRef &symbol) {
- return Twine("s.") + symbol;
-}
-
// Gets the dynamic value pack's name by removing the index suffix from
// `symbol`. Returns `symbol` itself if it does not contain an index.
//
// Handle symbols bound to matched op arguments
auto srcArgIt = sourceArguments.find(symbol);
if (srcArgIt != sourceArguments.end())
- return getBoundSymbol(symbol).str();
+ return symbol;
// Handle symbols bound to matched op results
auto srcOpIt = sourceOps.find(name);
if (srcOpIt != sourceOps.end())
- return formatValuePack("{0}->getResult({1})", getBoundSymbol(symbol).str(),
+ return formatValuePack("{0}->getResult({1})", symbol,
srcOpIt->second->getNumResults(), /*offset=*/0);
return {};
}
namespace {
class PatternEmitter {
public:
- static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
- raw_ostream &os);
-
-private:
PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
// Emits the mlir::RewritePattern struct named `rewriteName`.
void emit(StringRef rewriteName);
- // Emits the match() method.
- void emitMatchMethod(DagNode tree);
+private:
+ // Emits the code for matching ops.
+ void emitMatchLogic(DagNode tree);
- // Collects all of the operations within the given dag tree.
- void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
+ // Emits the code for rewriting ops.
+ void emitRewriteLogic();
- // Emits the rewrite() method.
- void emitRewriteMethod();
+ //===--------------------------------------------------------------------===//
+ // Match utilities
+ //===--------------------------------------------------------------------===//
// Emits C++ statements for matching the op constrained by the given DAG
// `tree`.
// `tree` as an attribute.
void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
- // Returns a unique name for an value of the given `op`.
- std::string getUniqueValueName(const Operator *op);
+ //===--------------------------------------------------------------------===//
+ // Rewrite utilities
+ //===--------------------------------------------------------------------===//
- // Entry point for handling a rewrite pattern rooted at `resultTree` and
+ // Entry point for handling a result pattern rooted at `resultTree` and
// dispatches to concrete handlers. The given tree is the `resultIndex`-th
// argument of the enclosing DAG.
- std::string handleRewritePattern(DagNode resultTree, int resultIndex,
- int depth);
+ std::string handleResultPattern(DagNode resultTree, int resultIndex,
+ int depth);
// Emits the C++ statement to replace the matched DAG with a value built via
// calling native C++ code.
- std::string emitReplaceWithNativeCodeCall(DagNode resultTree);
+ std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
// Returns the C++ expression referencing the old value serving as the
// replacement.
// DAG `tree` has a specified name, the created op will be assigned to a
// variable of the given name. Otherwise, a unique name will be used as the
// result value name.
- std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
+ std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
// Returns the C++ expression to construct a constant attribute of the given
// `value` for the given attribute kind `attr`.
// `patArgName` is used to bound the argument to the source pattern.
std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
+ //===--------------------------------------------------------------------===//
+ // General utilities
+ //===--------------------------------------------------------------------===//
+
+ // Collects all of the operations within the given dag tree.
+ void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
+
+ // Returns a unique name for a value of the given `op`.
+ std::string getUniqueValueName(const Operator *op);
+
+ //===--------------------------------------------------------------------===//
+ // Symbol utilities
+ //===--------------------------------------------------------------------===//
+
// Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol
// is already bound.
void addSymbol(StringRef symbol, int numValues);
unsigned nextValueId;
raw_ostream &os;
- // Format contexts containing placeholder substitutations for match().
- FmtContext matchCtx;
- // Format contexts containing placeholder substitutations for rewrite().
- FmtContext rewriteCtx;
+ // Format contexts containing placeholder substitutations.
+ FmtContext fmtCtx;
// Number of op processed.
int opCounter = 0;
symbolResolver(pattern.getSourcePatternBoundArgs(),
pattern.getSourcePatternBoundOps()),
nextValueId(0), os(os) {
- matchCtx.withBuilder("mlir::Builder(ctx)");
- rewriteCtx.withBuilder("rewriter");
+ fmtCtx.withBuilder("rewriter");
}
std::string PatternEmitter::handleConstantAttr(Attribute attr,
" does not have the 'constBuilderCall' field");
// TODO(jpienaar): Verify the constants here
- return tgfmt(attr.getConstBuilderTemplate(),
- &rewriteCtx.withBuilder("rewriter"), value);
+ return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
}
// Helper function to match patterns.
// If the operand's name is set, set to that variable.
auto name = tree.getSymbol();
if (!name.empty())
- os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth);
+ os.indent(indent) << formatv("{0} = op{1};\n", name, depth);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
depth + 1, depth, i);
emitOpMatch(argTree, depth + 1);
os.indent(indent + 2)
- << formatv("s.autogeneratedRewritePatternOps[{0}] = op{1};\n",
- ++opCounter, depth + 1);
+ << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
os.indent(indent) << "}\n";
continue;
}
auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
os.indent(indent) << "if (!("
<< tgfmt(matcher.getConditionTemplate(),
- &matchCtx.withSelf(self))
+ &fmtCtx.withSelf(self))
<< ")) return matchFailure();\n";
}
}
// Capture the value
auto name = tree.getArgName(index);
if (!name.empty()) {
- os.indent(indent) << getBoundSymbol(name) << " = op" << depth
- << "->getOperand(" << index << ");\n";
+ os.indent(indent) << formatv("{0} = op{1}->getOperand({2});\n", name, depth,
+ index);
}
}
os.indent(indent) << "{\n";
indent += 2;
os.indent(indent) << formatv(
- "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
+ "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
attr.getStorageType(), namedAttr->name);
// TODO(antiagainst): This should use getter method to avoid duplication.
if (attr.hasDefaultValueInitializer()) {
- os.indent(indent) << "if (!attr) attr = "
- << tgfmt(attr.getConstBuilderTemplate(), &matchCtx,
+ os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
+ << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
attr.getDefaultValueInitializer())
<< ";\n";
} else if (attr.isOptional()) {
- // For a missing attribut that is optional according to definition, we
+ // For a missing attribute that is optional according to definition, we
// should just capature a mlir::Attribute() to signal the missing state.
// That is precisely what getAttr() returns on missing attributes.
} else {
- os.indent(indent) << "if (!attr) return matchFailure();\n";
+ os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
}
auto matcher = tree.getArgAsLeaf(index);
// check the constraint.
os.indent(indent) << "if (!("
<< tgfmt(matcher.getConditionTemplate(),
- &matchCtx.withSelf("attr"))
+ &fmtCtx.withSelf("tblgen_attr"))
<< ")) return matchFailure();\n";
}
// Capture the value
auto name = tree.getArgName(index);
if (!name.empty()) {
- os.indent(indent) << getBoundSymbol(name) << " = attr;\n";
+ os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
}
indent -= 2;
os.indent(indent) << "}\n";
}
-void PatternEmitter::emitMatchMethod(DagNode tree) {
- // Emit the heading.
- os << R"(
- PatternMatchResult match(Operation *op0) const override {
- auto ctx = op0->getContext(); (void)ctx;
- auto state = llvm::make_unique<MatchedState>();
- auto &s = *state;
- s.autogeneratedRewritePatternOps[0] = op0;
-)";
-
+void PatternEmitter::emitMatchLogic(DagNode tree) {
emitOpMatch(tree, 0);
for (auto &appliedConstraint : pattern.getConstraints()) {
if (isa<TypeConstraint>(constraint)) {
auto self = formatv("({0}->getType())", resolveSymbol(entities.front()));
os.indent(4) << formatv(cmd,
- tgfmt(condition, &matchCtx.withSelf(self.str())));
+ tgfmt(condition, &fmtCtx.withSelf(self.str())));
} else if (isa<AttrConstraint>(constraint)) {
PrintFatalError(
loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
for (; i < 4; ++i)
names.push_back("<unused>");
os.indent(4) << formatv(cmd,
- tgfmt(condition, &matchCtx.withSelf(self),
- names[0], names[1], names[2], names[3]));
+ tgfmt(condition, &fmtCtx.withSelf(self), names[0],
+ names[1], names[2], names[3]));
}
}
-
- os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
}
void PatternEmitter::collectOps(DagNode tree,
}
void PatternEmitter::emit(StringRef rewriteName) {
- // Get the DAG tree for the source pattern
- DagNode tree = pattern.getSourcePattern();
+ // Get the DAG tree for the source pattern.
+ DagNode sourceTree = pattern.getSourcePattern();
const Operator &rootOp = pattern.getSourceRootOp();
auto rootName = rootOp.getOperationName();
// Collect the set of result operations.
- llvm::SmallPtrSet<const Operator *, 4> results;
+ llvm::SmallPtrSet<const Operator *, 4> resultOps;
for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
- collectOps(pattern.getResultPattern(i), results);
+ collectOps(pattern.getResultPattern(i), resultOps);
// Emit RewritePattern for Pattern.
auto locs = pattern.getLocation();
os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
make_range(locs.rbegin(), locs.rend()));
os << formatv(R"(struct {0} : public RewritePattern {
- {0}(MLIRContext *context) : RewritePattern("{1}", {{)",
+ {0}(MLIRContext *context)
+ : RewritePattern("{1}", {{)",
rewriteName, rootName);
- interleaveComma(results, os, [&](const Operator *op) {
+ interleaveComma(resultOps, os, [&](const Operator *op) {
os << '"' << op->getOperationName() << '"';
});
os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
- // Emit matched state.
- os << " struct MatchedState : public PatternState {\n";
+ // Emit matchAndRewrite() function.
+ os << R"(
+ PatternMatchResult matchAndRewrite(Operation *op0,
+ PatternRewriter &rewriter) const override {
+)";
+
+ os.indent(4) << "// Variables for capturing values and attributes used for "
+ "creating ops\n";
+ // Create local variables for storing the arguments bound to symbols.
for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
auto fieldName = arg.first();
if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
- os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName
- << ";\n";
+ os.indent(4) << formatv("{0} {1};\n", namedAttr->attr.getStorageType(),
+ fieldName);
} else {
os.indent(4) << "Value *" << fieldName << ";\n";
}
}
+ // Create local variables for storing the ops bound to symbols.
for (const auto &result : pattern.getSourcePatternBoundOps()) {
- os.indent(4) << "Operation *" << result.getKey() << ";\n";
+ os.indent(4) << formatv("Operation *{0};\n", result.getKey());
}
- // TODO(jpienaar): Change to matchAndRewrite & capture ops with consistent
- // numbering so that it can be reused for fused loc.
- os.indent(4) << "Operation* autogeneratedRewritePatternOps["
- << pattern.getSourcePattern().getNumOps() << "];\n";
- os << " };\n";
+ // TODO(jpienaar): capture ops with consistent numbering so that it can be
+ // reused for fused loc.
+ os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
+ pattern.getSourcePattern().getNumOps());
- emitMatchMethod(tree);
- emitRewriteMethod();
+ os.indent(4) << "// Match\n";
+ os.indent(4) << "tblgen_ops[0] = op0;\n";
+ emitMatchLogic(sourceTree);
+ os << "\n";
+ os.indent(4) << "// Rewrite\n";
+ emitRewriteLogic();
+
+ os.indent(4) << "return matchSuccess();\n";
+ os << " };\n";
os << "};\n";
}
-void PatternEmitter::emitRewriteMethod() {
+void PatternEmitter::emitRewriteLogic() {
const Operator &rootOp = pattern.getSourceRootOp();
int numExpectedResults = rootOp.getNumResults();
int numResultPatterns = pattern.getNumResultPatterns();
PrintFatalError(loc, error);
}
- os << R"(
- void rewrite(Operation *op, std::unique_ptr<PatternState> state,
- PatternRewriter &rewriter) const override {
- auto& s = *static_cast<MatchedState *>(state.get());
- auto loc = rewriter.getFusedLoc({)";
+ os.indent(4) << "auto loc = rewriter.getFusedLoc({";
for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
- os << (i ? ", " : "") << "s.autogeneratedRewritePatternOps[" << i
- << "]->getLoc()";
+ os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
}
os << "}); (void)loc;\n";
llvm::SmallVector<std::string, 2> resultValues;
for (int i = 0; i < numResultPatterns; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
- resultValues.push_back(handleRewritePattern(resultTree, offsets[i], 0));
+ resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0));
}
// Emit the final replaceOp() statement
- os.indent(4) << "rewriter.replaceOp(op, {";
+ os.indent(4) << "rewriter.replaceOp(op0, {";
interleave(
ArrayRef<std::string>(resultValues).drop_front(replStartIndex),
[&](const std::string &name) { os << name; }, [&]() { os << ", "; });
- os << "});\n }\n";
+ os << "});\n";
}
std::string PatternEmitter::getUniqueValueName(const Operator *op) {
return formatv("v{0}{1}", op->getCppClassName(), nextValueId++);
}
-std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
- int resultIndex, int depth) {
+std::string PatternEmitter::handleResultPattern(DagNode resultTree,
+ int resultIndex, int depth) {
if (resultTree.isNativeCodeCall())
- return emitReplaceWithNativeCodeCall(resultTree);
+ return handleReplaceWithNativeCodeCall(resultTree);
if (resultTree.isReplaceWithValue())
return handleReplaceWithValue(resultTree);
// Create the op and get the local variable for it.
- auto results = emitOpCreate(resultTree, resultIndex, depth);
+ auto results = handleOpCreation(resultTree, resultIndex, depth);
// We need to get all the values out of this local variable if we've created a
// multi-result op.
const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
return handleConstantAttr(enumCase, val);
}
pattern.ensureBoundInSourcePattern(argName);
- std::string result = getBoundSymbol(argName).str();
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
- return result;
+ return argName;
}
if (leaf.isNativeCodeCall()) {
- return tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(result));
+ return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
}
PrintFatalError(loc, "unhandled case when rewriting op");
}
-std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) {
+std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
auto fmt = tree.getNativeCodeTemplate();
// TODO(b/138794486): replace formatv arguments with the exact specified args.
SmallVector<std::string, 8> attrs(8);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
}
- return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3],
- attrs[4], attrs[5], attrs[6], attrs[7]);
+ return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
+ attrs[5], attrs[6], attrs[7]);
}
void PatternEmitter::addSymbol(StringRef symbol, int numValues) {
return 1;
}
-std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
- int depth) {
+std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
+ int depth) {
Operator &resultOp = tree.getDialectOp(opMap);
auto numOpArgs = resultOp.getNumArgs();
// This happens in a recursive manner.
for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
if (auto child = tree.getArgAsNestedDag(i)) {
- childNodeNames[i] = handleRewritePattern(child, i, depth + 1);
+ childNodeNames[i] = handleResultPattern(child, i, depth + 1);
}
}
// We need to specify the types for all results.
auto resultTypes =
- formatValuePack("op->getResult({1})->getType()", valuePackName,
+ formatValuePack("op0->getResult({1})->getType()", valuePackName,
resultOp.getNumResults(), resultIndex);
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
DagLeaf leaf = tree.getArgAsLeaf(i);
auto symbol = resolveSymbol(tree.getArgName(i));
if (leaf.isNativeCodeCall()) {
- os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol));
+ os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
} else {
os << symbol;
}
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os << formatv("/*{0}=*/{1}", opArgName,
- emitReplaceWithNativeCodeCall(subTree));
+ handleReplaceWithNativeCodeCall(subTree));
} else {
auto leaf = tree.getArgAsLeaf(i);
// The argument in the result DAG pattern.
return resultValue;
}
-void PatternEmitter::emit(StringRef rewriteName, Record *p,
- RecordOperatorMap *mapper, raw_ostream &os) {
- PatternEmitter(p, mapper, os).emit(rewriteName);
-}
-
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Rewriters", os);
} else {
name = p->getName();
}
- PatternEmitter::emit(name, p, &recordOpMap, os);
+ PatternEmitter(p, &recordOpMap, os).emit(name);
rewriterNames.push_back(std::move(name));
}