Add op stats pass to mlir-opt.
authorJacques Pienaar <jpienaar@google.com>
Tue, 20 Nov 2018 17:38:15 +0000 (09:38 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:02:46 +0000 (14:02 -0700)
op-stats pass currently returns the number of occurrences of different operations in a Module. Useful for verifying transformation properties (e.g., 3 ops of specific dialect, 0 of another), but probably not useful outside of that so keeping it local to mlir-opt. This does not consider op attributes when counting.

PiperOrigin-RevId: 222259727

mlir/lib/Transforms/CFGFunctionViewGraph.cpp
mlir/test/IR/op-stats.mlir [new file with mode: 0644]
mlir/tools/mlir-opt/OpStats.cpp [new file with mode: 0644]

index d29708d4fef10b2a71a44334707e63dbccb5c1fb..49aa2876e526d25d06e34fa5dae4a889ead6469c 100644 (file)
@@ -1,4 +1,4 @@
-//===- CFGFunctionViewGraph.h - View/write graphviz graphs ------*- C++ -*-===//
+//===- CFGFunctionViewGraph.cpp - View/write graphviz graphs --------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
diff --git a/mlir/test/IR/op-stats.mlir b/mlir/test/IR/op-stats.mlir
new file mode 100644 (file)
index 0000000..27674ef
--- /dev/null
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -print-op-stats %s -o=/dev/null 2>&1 | FileCheck %s
+
+cfgfunc @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
+bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
+  %0 = addf %arg0, %arg1 : tensor<4xf32>
+  %1 = addf %arg0, %arg1 : tensor<4xf32>
+  %2 = addf %arg0, %arg1 : tensor<4xf32>
+  %3 = addf %arg0, %arg1 : tensor<4xf32>
+  %4 = addf %arg0, %arg1 : tensor<4xf32>
+  %5 = addf %arg0, %arg1 : tensor<4xf32>
+  %10 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %11 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %12 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %13 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %14 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %15 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %16 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %17 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %18 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %19 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %20 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %21 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %22 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %23 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %24 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %25 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %26 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %30 = "long_op_name"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
+// CHECK-LABEL: Operations encountered
+// CHECK: 'addf' , 6
+// CHECK: 'long_op_name' , 1
+// CHECK: 'return' , 1
+// CHECK: 'xla.add' , 17
diff --git a/mlir/tools/mlir-opt/OpStats.cpp b/mlir/tools/mlir-opt/OpStats.cpp
new file mode 100644 (file)
index 0000000..0c9514a
--- /dev/null
@@ -0,0 +1,125 @@
+//===- OpStats.cpp - Prints stats of operations in module -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
+#include "mlir/Pass.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+namespace {
+struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {
+  explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs())
+      : FunctionPass(&PrintOpStatsPass::passID), os(os) {}
+
+  // Prints the resultant operation stats post iterating over the module.
+  PassResult runOnModule(Module *m) override;
+
+  // Process CFG function considering the instructions in basic blocks.
+  PassResult runOnCFGFunction(CFGFunction *function) override;
+
+  // Process ML functions and operation statments in ML functions.
+  PassResult runOnMLFunction(MLFunction *function) override;
+  void visitOperationStmt(OperationStmt *stmt);
+
+  // Print summary of op stats.
+  void printSummary();
+
+  static char passID;
+
+private:
+  llvm::StringMap<int64_t> opCount;
+
+  llvm::raw_ostream &os;
+};
+} // namespace
+
+char PrintOpStatsPass::passID = 0;
+
+PassResult PrintOpStatsPass::runOnModule(Module *m) {
+  auto result = FunctionPass::runOnModule(m);
+  if (!result)
+    printSummary();
+  return result;
+}
+
+PassResult PrintOpStatsPass::runOnCFGFunction(CFGFunction *function) {
+  for (const auto &bb : *function)
+    for (const auto &inst : bb)
+      ++opCount[inst.getName().getStringRef()];
+  return success();
+}
+
+void PrintOpStatsPass::visitOperationStmt(OperationStmt *stmt) {
+  ++opCount[stmt->getName().getStringRef()];
+}
+
+PassResult PrintOpStatsPass::runOnMLFunction(MLFunction *function) {
+  walk(function);
+  return success();
+}
+
+void PrintOpStatsPass::printSummary() {
+  os << "Operations encountered:\n";
+  os << "-----------------------\n";
+  std::vector<StringRef> sorted(opCount.keys().begin(), opCount.keys().end());
+  llvm::sort(sorted);
+
+  // Returns the lenght of the dialect prefix of an op.
+  auto dialectLen = [](StringRef opName) -> size_t {
+    auto dialectEnd = opName.find_last_of('.');
+    if (dialectEnd == StringRef::npos)
+      return 0;
+    // Count the periond too.
+    return dialectEnd + 1;
+  };
+
+  // Left-align the names (aligning on the dialect) and right-align count below.
+  // The alignment is for readability and does not affect CSV/FileCheck parsing.
+  size_t maxLenName = 0;
+  size_t maxLenNamePrefixLen = 0;
+  size_t maxLenDialect = 0;
+  int maxLenCount = 0;
+  for (const auto &key : sorted) {
+    size_t len = key.size();
+    size_t prefix = dialectLen(key);
+    if (len > maxLenName) {
+      maxLenName = len;
+      maxLenNamePrefixLen = prefix;
+    }
+    maxLenDialect = max(maxLenDialect, prefix);
+    // This takes advantage of the fact that opCount[key] > 0.
+    maxLenCount = max(maxLenCount, (int)log10(opCount[key]) + 1);
+  }
+  // Adjust the max name length to account for the dialect.
+  maxLenName += (maxLenDialect - maxLenNamePrefixLen);
+
+  for (const auto &key : sorted) {
+    size_t prefix = maxLenDialect - dialectLen(key);
+    os.indent(2 + prefix) << '\'' << key << '\'';
+    os.indent(maxLenName - key.size() - prefix) << " ,";
+    os.indent(maxLenCount - (int)log10(opCount[key])) << opCount[key] << "\n";
+  }
+}
+
+static PassRegistration<PrintOpStatsPass>
+    pass("print-op-stats", "Print statistics of operations");