Add support to RewritePattern for specifying the potential operations that can...
authorRiver Riddle <riverriddle@google.com>
Sat, 25 May 2019 02:35:23 +0000 (19:35 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:02:42 +0000 (20:02 -0700)
--

PiperOrigin-RevId: 249936309

mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/IR/PatternMatch.cpp
mlir/lib/TableGen/Pattern.cpp
mlir/test/mlir-tblgen/pattern-benefit.td
mlir/test/mlir-tblgen/pattern.td
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 340e844..60c8255 100644 (file)
@@ -184,12 +184,25 @@ public:
     return matchFailure();
   }
 
+  /// Return a list of operations that may be generated when rewriting an
+  /// operation instance with this pattern.
+  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
+
 protected:
   /// Patterns must specify the root operation name they match against, and can
   /// also specify the benefit of the pattern matching.
   RewritePattern(StringRef rootName, PatternBenefit benefit,
                  MLIRContext *context)
       : Pattern(rootName, benefit, context) {}
+  /// Patterns must specify the root operation name they match against, and can
+  /// also specify the benefit of the pattern matching. They can also specify
+  /// the names of operations that may be generated during a successful rewrite.
+  RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
+                 PatternBenefit benefit, MLIRContext *context);
+
+  /// A list of the potential operations that may be generated when rewriting
+  /// an op with this pattern.
+  llvm::SmallVector<OperationName, 2> generatedOps;
 };
 
 //===----------------------------------------------------------------------===//
index f5eb9a3..79d7e98 100644 (file)
@@ -172,6 +172,9 @@ public:
   // Returns true if this DAG node is wrapping native code call.
   bool isNativeCodeCall() const;
 
+  // Returns true if this DAG node is an operation.
+  bool isOperation() const;
+
   // Returns the native code call template inside this DAG node.
   // Precondition: isNativeCodeCall()
   llvm::StringRef getNativeCodeTemplate() const;
index 1943073..ac23851 100644 (file)
@@ -60,6 +60,20 @@ PatternMatchResult RewritePattern::match(Operation *op) const {
   llvm_unreachable("need to implement either match or matchAndRewrite!");
 }
 
+/// Patterns must specify the root operation name they match against, and can
+/// also specify the benefit of the pattern matching. They can also specify the
+/// names of operations that may be generated during a successful rewrite.
+RewritePattern::RewritePattern(StringRef rootName,
+                               ArrayRef<StringRef> generatedNames,
+                               PatternBenefit benefit, MLIRContext *context)
+    : Pattern(rootName, benefit, context) {
+  generatedOps.reserve(generatedNames.size());
+  std::transform(generatedNames.begin(), generatedNames.end(),
+                 std::back_inserter(generatedOps), [context](StringRef name) {
+                   return OperationName(name, context);
+                 });
+}
+
 PatternRewriter::~PatternRewriter() {
   // Out of line to provide a vtable anchor for the class.
 }
index 31bab81..e2ddcba 100644 (file)
@@ -94,6 +94,10 @@ bool tblgen::DagNode::isNativeCodeCall() const {
   return false;
 }
 
+bool tblgen::DagNode::isOperation() const {
+  return !(isNativeCodeCall() || isVerifyUnusedValue() || isReplaceWithValue());
+}
+
 llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
   return cast<llvm::DefInit>(node->getOperator())
index 61db84b..36bc2c7 100644 (file)
@@ -26,9 +26,9 @@ def Z_AddOp : NS_Op<"add"> {
 def bena : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
 
 // CHECK-LABEL: struct bena
-// CHECK: RewritePattern("x.add", 2, context) {}
+// CHECK: RewritePattern("x.add", {"x.add"}, 2, context) {}
 
 def benb : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
 
 // CHECK-LABEL: struct benb
-// CHECK: RewritePattern("x.add", 101, context) {}
+// CHECK: RewritePattern("x.add", {"x.add"}, 101, context) {}
index b5a6c60..66ff381 100644 (file)
@@ -34,7 +34,7 @@ def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
 
 // CHECK: struct GeneratedConvert0 : public RewritePattern
 
-// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", 1, context) {}
+// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", {"op_b"}, 1, context) {}
 
 // CHECK: struct MatchedState : public PatternState {
 // CHECK:   Value *input;
index c29cf54..9103cb0 100644 (file)
@@ -153,6 +153,9 @@ private:
   // Emits the match() method.
   void emitMatchMethod(DagNode tree);
 
+  // Collects all of the operations within the given dag tree.
+  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
+
   // Emits the rewrite() method.
   void emitRewriteMethod();
 
@@ -443,6 +446,18 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
 }
 
+void PatternEmitter::collectOps(DagNode tree,
+                                llvm::SmallPtrSetImpl<const Operator *> &ops) {
+  // Check if this tree is an operation.
+  if (tree.isOperation())
+    ops.insert(&tree.getDialectOp(opMap));
+
+  // Recurse the arguments of the tree.
+  for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
+    if (auto child = tree.getArgAsNestedDag(i))
+      collectOps(child, ops);
+}
+
 void PatternEmitter::emit(StringRef rewriteName) {
   // Get the DAG tree for the source pattern
   DagNode tree = pattern.getSourcePattern();
@@ -454,14 +469,22 @@ void PatternEmitter::emit(StringRef rewriteName) {
     PrintFatalError(
         loc, "replacing op with variadic results not supported right now");
 
+  // Collect the set of result operations.
+  llvm::SmallPtrSet<const Operator *, 4> results;
+  for (unsigned i = 0, e = pattern.getNumResults(); i != e; ++i)
+    collectOps(pattern.getResultPattern(i), results);
+
   // Emit RewritePattern for Pattern.
   auto locs = pattern.getLocation();
   os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
                 make_range(locs.rbegin(), locs.rend()));
   os << formatv(R"(struct {0} : public RewritePattern {
-  {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
-                rewriteName, rootName, pattern.getBenefit())
-     << "\n";
+  {0}(MLIRContext *context) : RewritePattern("{1}", {{)",
+                rewriteName, rootName);
+  interleaveComma(results, os, [&](const Operator *op) {
+    os << '"' << op->getOperationName() << '"';
+  });
+  os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
 
   // Emit matched state.
   os << "  struct MatchedState : public PatternState {\n";