[mlir] add an option to print op stats in JSON
authorOkwan Kwon <okwan@google.com>
Tue, 14 Jun 2022 21:16:26 +0000 (14:16 -0700)
committerOkwan Kwon <okwan@google.com>
Wed, 15 Jun 2022 17:07:36 +0000 (10:07 -0700)
Differential Revision: https://reviews.llvm.org/D127691

mlir/include/mlir/Transforms/Passes.td
mlir/lib/Transforms/OpStats.cpp
mlir/test/CAPI/pass.c
mlir/test/IR/op-stats-json.mlir [new file with mode: 0644]
mlir/test/python/pass_manager.py

index 7ac71ce..8b8e6a1 100644 (file)
@@ -145,6 +145,10 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> {
 def PrintOpStats : Pass<"print-op-stats"> {
   let summary = "Print statistics of operations";
   let constructor = "mlir::createPrintOpStatsPass()";
+  let options = [
+    Option<"printAsJSON", "json", "bool", /*default=*/"false",
+           "print the stats as JSON">
+  ];
 }
 
 def SCCP : Pass<"sccp"> {
index 8adc411..e7740ab 100644 (file)
@@ -27,6 +27,9 @@ struct PrintOpStatsPass : public PrintOpStatsBase<PrintOpStatsPass> {
   // Print summary of op stats.
   void printSummary();
 
+  // Print symmary of op stats in JSON.
+  void printSummaryInJSON();
+
 private:
   llvm::StringMap<int64_t> opCount;
   raw_ostream &os;
@@ -37,8 +40,12 @@ void PrintOpStatsPass::runOnOperation() {
   opCount.clear();
 
   // Compute the operation statistics for the currently visited operation.
-  getOperation()->walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
-  printSummary();
+  getOperation()->walk(
+      [&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
+  if (printAsJSON) {
+    printSummaryInJSON();
+  } else
+    printSummary();
 }
 
 void PrintOpStatsPass::printSummary() {
@@ -80,6 +87,23 @@ void PrintOpStatsPass::printSummary() {
   }
 }
 
+void PrintOpStatsPass::printSummaryInJSON() {
+  SmallVector<StringRef, 64> sorted(opCount.keys());
+  llvm::sort(sorted);
+
+  os << "{\n";
+
+  for (unsigned i = 0, e = sorted.size(); i != e; ++i) {
+    const auto &key = sorted[i];
+    os << "  \"" << key << "\" : " << opCount[key];
+    if (i != e - 1)
+      os << ",\n";
+    else
+      os << "\n";
+  }
+  os << "}\n";
+}
+
 std::unique_ptr<Pass> mlir::createPrintOpStatsPass(raw_ostream &os) {
   return std::make_unique<PrintOpStatsPass>(os);
 }
index 63aba29..f73398e 100644 (file)
@@ -138,14 +138,14 @@ void testPrintPassPipeline() {
   mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
 
   // Print the top level pass manager
-  // CHECK: Top-level: builtin.module(func.func(print-op-stats))
+  // CHECK: Top-level: builtin.module(func.func(print-op-stats{json=false}))
   fprintf(stderr, "Top-level: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);
   fprintf(stderr, "\n");
 
   // Print the pipeline nested one level down
-  // CHECK: Nested Module: func.func(print-op-stats)
+  // CHECK: Nested Module: func.func(print-op-stats{json=false})
   fprintf(stderr, "Nested Module: ");
   mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
   fprintf(stderr, "\n");
@@ -166,8 +166,9 @@ void testParsePassPipeline() {
   // Try parse a pipeline.
   MlirLogicalResult status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
-      mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats),"
-                                     " func.func(print-op-stats))"));
+      mlirStringRefCreateFromCString(
+          "builtin.module(func.func(print-op-stats{json=false}),"
+          " func.func(print-op-stats{json=false}))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsSuccess(status)) {
     fprintf(
@@ -179,8 +180,9 @@ void testParsePassPipeline() {
   mlirRegisterTransformsPrintOpStats();
   status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
-      mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats),"
-                                     " func.func(print-op-stats))"));
+      mlirStringRefCreateFromCString(
+          "builtin.module(func.func(print-op-stats{json=false}),"
+          " func.func(print-op-stats{json=false}))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsFailure(status)) {
     fprintf(stderr,
@@ -188,8 +190,8 @@ void testParsePassPipeline() {
     exit(EXIT_FAILURE);
   }
 
-  // CHECK: Round-trip: builtin.module(func.func(print-op-stats),
-  // func.func(print-op-stats))
+  // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}),
+  // func.func(print-op-stats{json=false}))
   fprintf(stderr, "Round-trip: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);
diff --git a/mlir/test/IR/op-stats-json.mlir b/mlir/test/IR/op-stats-json.mlir
new file mode 100644 (file)
index 0000000..40b0602
--- /dev/null
@@ -0,0 +1,37 @@
+// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats=json %s -o=/dev/null 2>&1 | FileCheck %s
+
+func.func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
+^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
+  %0 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %1 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %2 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %3 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %4 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %5 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %10 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %11 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %12 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %13 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %14 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %15 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %16 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %17 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %18 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %19 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %20 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %21 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %22 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %23 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %24 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %25 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %26 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %30 = "long_op_name"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
+// CHECK: {
+// CHECK:   "arith.addf" : 6,
+// CHECK:   "func.return" : 1,
+// CHECK:   "long_op_name" : 1,
+// CHECK:   "xla.add" : 17
+// CHECK: }
index c046bb8..6cc627d 100644 (file)
@@ -36,19 +36,19 @@ def testParseSuccess():
     # A first import is expected to fail because the pass isn't registered
     # until we import mlir.transforms
     try:
-      pm = PassManager.parse("builtin.module(func.func(print-op-stats))")
+      pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
       # TODO: this error should be propagate to Python but the C API does not help right now.
       # CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline
     except ValueError as e:
-      # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(print-op-stats))'.
+      # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(print-op-stats{json=false}))'.
       log("ValueError exception:", e)
     else:
       log("Exception not produced")
 
     # This will register the pass and round-trip should be possible now.
     import mlir.transforms
-    pm = PassManager.parse("builtin.module(func.func(print-op-stats))")
-    # CHECK: Roundtrip: builtin.module(func.func(print-op-stats))
+    pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
+    # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
     log("Roundtrip: ", pm)
 run(testParseSuccess)
 
@@ -86,7 +86,7 @@ run(testInvalidNesting)
 # CHECK-LABEL: TEST: testRun
 def testRunPipeline():
   with Context():
-    pm = PassManager.parse("print-op-stats")
+    pm = PassManager.parse("print-op-stats{json=false}")
     module = Module.parse(r"""func.func @successfulParse() { return }""")
     pm.run(module)
 # CHECK: Operations encountered: