[TableGen] Support naming rewrite rules
authorLei Zhang <antiagainst@google.com>
Wed, 10 Apr 2019 18:37:53 +0000 (11:37 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 11 Apr 2019 17:52:43 +0000 (10:52 -0700)
--

PiperOrigin-RevId: 242909061

mlir/test/mlir-tblgen/pattern.td
mlir/tools/mlir-tblgen/RewriterGen.cpp

index bea44c9..c3c4b20 100644 (file)
@@ -12,6 +12,13 @@ def OpB : Op<"op_b", []> {
   let results = (outs I32:$result);
 }
 
+def MyRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
+
+// Test rewrite rule naming
+// ---
+
+// CHECK: struct MyRule : public RewritePattern
+
 def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
 
 // Test basic structure generated from Pattern
@@ -31,5 +38,7 @@ def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
 // CHECK: void rewrite(Operation *op, std::unique_ptr<PatternState> state,
 // CHECK:              PatternRewriter &rewriter) const override
 
+
 // CHECK: void populateWithGenerated(MLIRContext *context, OwningRewritePatternList *patterns)
+// CHECK:   patterns->push_back(llvm::make_unique<MyRule>(context));
 // CHECK:   patterns->push_back(llvm::make_unique<GeneratedConvert0>(context));
index 3b3dfd3..ae99a05 100644 (file)
@@ -774,25 +774,38 @@ void PatternEmitter::emit(StringRef rewriteName, Record *p,
 
 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Rewriters", os);
+
   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
+  auto numPatterns = patterns.size();
 
   // We put the map here because it can be shared among multiple patterns.
   RecordOperatorMap recordOpMap;
 
-  // Ensure unique patterns simply by appending unique suffix.
-  std::string baseRewriteName = "GeneratedConvert";
-  int rewritePatternCount = 0;
+  std::vector<std::string> rewriterNames;
+  rewriterNames.reserve(numPatterns);
+
+  std::string baseRewriterName = "GeneratedConvert";
+  int rewriterIndex = 0;
+
   for (Record *p : patterns) {
-    PatternEmitter::emit(baseRewriteName + llvm::utostr(rewritePatternCount++),
-                         p, &recordOpMap, os);
+    std::string name;
+    if (p->isAnonymous()) {
+      // If no name is provided, ensure unique rewriter names simply by
+      // appending unique suffix.
+      name = baseRewriterName + llvm::utostr(rewriterIndex++);
+    } else {
+      name = p->getName();
+    }
+    PatternEmitter::emit(name, p, &recordOpMap, os);
+    rewriterNames.push_back(std::move(name));
   }
 
   // Emit function to add the generated matchers to the pattern list.
   os << "void populateWithGenerated(MLIRContext *context, "
      << "OwningRewritePatternList *patterns) {\n";
-  for (unsigned i = 0; i != rewritePatternCount; ++i) {
-    os.indent(2) << "patterns->push_back(llvm::make_unique<" << baseRewriteName
-                 << i << ">(context));\n";
+  for (const auto &name : rewriterNames) {
+    os << "  patterns->push_back(llvm::make_unique<" << name
+       << ">(context));\n";
   }
   os << "}\n";
 }