[mlir] Register the print-op-graph pass using ODS
authorJacques Pienaar <jpienaar@google.com>
Sat, 20 Feb 2021 23:42:02 +0000 (15:42 -0800)
committerJacques Pienaar <jpienaar@google.com>
Sat, 20 Feb 2021 23:42:02 +0000 (15:42 -0800)
Move over to ODS & use pass options.

mlir/include/mlir/Transforms/Passes.td
mlir/lib/Transforms/ViewOpGraph.cpp
mlir/test/Transforms/print-op-graph.mlir [new file with mode: 0644]

index a03b439..925e7e7 100644 (file)
@@ -730,4 +730,22 @@ def SymbolDCE : Pass<"symbol-dce"> {
   }];
   let constructor = "mlir::createSymbolDCEPass()";
 }
+
+def ViewOpGraphPass : Pass<"symbol-dce", "ModuleOp"> {
+  let summary = "Print graphviz view of module";
+  let description = [{
+    This pass prints a graphviz per block of a module.
+
+    - Op are represented as nodes;
+    - Uses as edges;
+  }];
+  let constructor = "mlir::createPrintOpGraphPass()";
+  let options = [
+    Option<"title", "title", "std::string",
+           /*default=*/"", "The prefix of the title of the graph">,
+    Option<"shortNames", "short-names", "bool", /*default=*/"false",
+           "Use short names">
+  ];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
index 97fe7b2..3d52d79 100644 (file)
@@ -104,10 +104,12 @@ namespace {
 // PrintOpPass is simple pass to write graph per function.
 // Note: this is a module pass only to avoid interleaving on the same ostream
 // due to multi-threading over functions.
-struct PrintOpPass : public PrintOpBase<PrintOpPass> {
-  explicit PrintOpPass(raw_ostream &os = llvm::errs(), bool short_names = false,
-                       const Twine &title = "")
-      : os(os), title(title.str()), short_names(short_names) {}
+class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
+public:
+  PrintOpPass(raw_ostream &os, bool shortNames, const Twine &title) : os(os) {
+    this->shortNames = shortNames;
+    this->title = title.str();
+  }
 
   std::string getOpName(Operation &op) {
     auto symbolAttr =
@@ -133,7 +135,7 @@ struct PrintOpPass : public PrintOpBase<PrintOpPass> {
           auto blockName = llvm::hasSingleElement(region)
                                ? ""
                                : ("__" + llvm::utostr(indexed_block.index()));
-          llvm::WriteGraph(os, &indexed_block.value(), short_names,
+          llvm::WriteGraph(os, &indexed_block.value(), shortNames,
                            Twine(title) + opName + blockName);
         }
       }
@@ -144,9 +146,7 @@ struct PrintOpPass : public PrintOpBase<PrintOpPass> {
 
 private:
   raw_ostream &os;
-  std::string title;
   int unnamedOpCtr = 0;
-  bool short_names;
 };
 } // namespace
 
diff --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir
new file mode 100644 (file)
index 0000000..1c4548e
--- /dev/null
@@ -0,0 +1,12 @@
+// RUN: mlir-opt -allow-unregistered-dialect -print-op-graph %s -o %t 2>&1 | FileCheck %s
+
+// CHECK-LABEL: digraph "merge_blocks"
+func @merge_blocks(%arg0: i32, %arg1 : i32) -> () {
+  %0:2 = "test.merge_blocks"() ({
+  ^bb0:
+     "test.br"(%arg0, %arg1)[^bb1] : (i32, i32) -> ()
+  ^bb1(%arg3 : i32, %arg4 : i32):
+     "test.return"(%arg3, %arg4) : (i32, i32) -> ()
+  }) : () -> (i32, i32)
+  "test.return"(%0#0, %0#1) : (i32, i32) -> ()
+}