return matchFailure();
}
+ /// Return a list of operations that may be generated when rewriting an
+ /// operation instance with this pattern.
+ ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
+
protected:
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
RewritePattern(StringRef rootName, PatternBenefit benefit,
MLIRContext *context)
: Pattern(rootName, benefit, context) {}
+ /// Patterns must specify the root operation name they match against, and can
+ /// also specify the benefit of the pattern matching. They can also specify
+ /// the names of operations that may be generated during a successful rewrite.
+ RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
+ PatternBenefit benefit, MLIRContext *context);
+
+ /// A list of the potential operations that may be generated when rewriting
+ /// an op with this pattern.
+ llvm::SmallVector<OperationName, 2> generatedOps;
};
//===----------------------------------------------------------------------===//
// Returns true if this DAG node is wrapping native code call.
bool isNativeCodeCall() const;
+ // Returns true if this DAG node is an operation.
+ bool isOperation() const;
+
// Returns the native code call template inside this DAG node.
// Precondition: isNativeCodeCall()
llvm::StringRef getNativeCodeTemplate() const;
llvm_unreachable("need to implement either match or matchAndRewrite!");
}
+/// Patterns must specify the root operation name they match against, and can
+/// also specify the benefit of the pattern matching. They can also specify the
+/// names of operations that may be generated during a successful rewrite.
+RewritePattern::RewritePattern(StringRef rootName,
+ ArrayRef<StringRef> generatedNames,
+ PatternBenefit benefit, MLIRContext *context)
+ : Pattern(rootName, benefit, context) {
+ generatedOps.reserve(generatedNames.size());
+ std::transform(generatedNames.begin(), generatedNames.end(),
+ std::back_inserter(generatedOps), [context](StringRef name) {
+ return OperationName(name, context);
+ });
+}
+
PatternRewriter::~PatternRewriter() {
// Out of line to provide a vtable anchor for the class.
}
return false;
}
+bool tblgen::DagNode::isOperation() const {
+ return !(isNativeCodeCall() || isVerifyUnusedValue() || isReplaceWithValue());
+}
+
llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
return cast<llvm::DefInit>(node->getOperator())
def bena : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
// CHECK-LABEL: struct bena
-// CHECK: RewritePattern("x.add", 2, context) {}
+// CHECK: RewritePattern("x.add", {"x.add"}, 2, context) {}
def benb : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
// CHECK-LABEL: struct benb
-// CHECK: RewritePattern("x.add", 101, context) {}
+// CHECK: RewritePattern("x.add", {"x.add"}, 101, context) {}
// CHECK: struct GeneratedConvert0 : public RewritePattern
-// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", 1, context) {}
+// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", {"op_b"}, 1, context) {}
// CHECK: struct MatchedState : public PatternState {
// CHECK: Value *input;
// Emits the match() method.
void emitMatchMethod(DagNode tree);
+ // Collects all of the operations within the given dag tree.
+ void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
+
// Emits the rewrite() method.
void emitRewriteMethod();
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
}
+void PatternEmitter::collectOps(DagNode tree,
+ llvm::SmallPtrSetImpl<const Operator *> &ops) {
+ // Check if this tree is an operation.
+ if (tree.isOperation())
+ ops.insert(&tree.getDialectOp(opMap));
+
+ // Recurse the arguments of the tree.
+ for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
+ if (auto child = tree.getArgAsNestedDag(i))
+ collectOps(child, ops);
+}
+
void PatternEmitter::emit(StringRef rewriteName) {
// Get the DAG tree for the source pattern
DagNode tree = pattern.getSourcePattern();
PrintFatalError(
loc, "replacing op with variadic results not supported right now");
+ // Collect the set of result operations.
+ llvm::SmallPtrSet<const Operator *, 4> results;
+ for (unsigned i = 0, e = pattern.getNumResults(); i != e; ++i)
+ collectOps(pattern.getResultPattern(i), results);
+
// Emit RewritePattern for Pattern.
auto locs = pattern.getLocation();
os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
make_range(locs.rbegin(), locs.rend()));
os << formatv(R"(struct {0} : public RewritePattern {
- {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
- rewriteName, rootName, pattern.getBenefit())
- << "\n";
+ {0}(MLIRContext *context) : RewritePattern("{1}", {{)",
+ rewriteName, rootName);
+ interleaveComma(results, os, [&](const Operator *op) {
+ os << '"' << op->getOperationName() << '"';
+ });
+ os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
// Emit matched state.
os << " struct MatchedState : public PatternState {\n";