NFC: Refactor the PassInstrumentation framework to operate on Operation instead of...
authorRiver Riddle <riverriddle@google.com>
Sat, 17 Aug 2019 00:59:03 +0000 (17:59 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 17 Aug 2019 00:59:37 +0000 (17:59 -0700)
Now that functions and modules are operations, Operation makes more sense as the opaque object to refer to both.

PiperOrigin-RevId: 263883913

mlir/g3doc/WritingAPass.md
mlir/include/mlir/Pass/PassInstrumentation.h
mlir/lib/Pass/IRPrinting.cpp
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassTiming.cpp

index dc06ace..47e57df 100644 (file)
@@ -389,8 +389,7 @@ struct DominanceCounterInstrumentation : public PassInstrumentation {
   unsigned &count;
 
   DominanceCounterInstrumentation(unsigned &count) : count(count) {}
-  void runAfterAnalysis(llvm::StringRef, AnalysisID *id,
-                        const llvm::Any &) override {
+  void runAfterAnalysis(llvm::StringRef, AnalysisID *id, Operation *) override {
     if (id == AnalysisID::getID<DominanceInfo>())
       ++count;
   }
index 4035832..46df6fd 100644 (file)
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/Any.h"
 #include "llvm/ADT/StringRef.h"
 
 namespace mlir {
 using AnalysisID = ClassID;
+class Operation;
 class Pass;
 
 namespace detail {
@@ -39,32 +39,32 @@ public:
   virtual ~PassInstrumentation() = 0;
 
   /// A callback to run before a pass is executed. This function takes a pointer
-  /// to the pass to be executed, as well as an llvm::Any holding a pointer to
-  /// the IR unit being transformed on.
-  virtual void runBeforePass(Pass *pass, const llvm::Any &ir) {}
+  /// to the pass to be executed, as well as the current operation being
+  /// operated on.
+  virtual void runBeforePass(Pass *pass, Operation *op) {}
 
   /// A callback to run after a pass is successfully executed. This function
-  /// takes a pointer to the pass to be executed, as well as an llvm::Any
-  /// holding a pointer to the IR unit being transformed on.
-  virtual void runAfterPass(Pass *pass, const llvm::Any &ir) {}
+  /// takes a pointer to the pass to be executed, as well as the current
+  /// operation being operated on.
+  virtual void runAfterPass(Pass *pass, Operation *op) {}
 
   /// A callback to run when a pass execution fails. This function takes a
-  /// pointer to the pass that was being executed, as well as an llvm::Any
-  /// holding a pointer to the IR unit that was being transformed. Note
-  /// that the ir unit may be in an invalid state.
-  virtual void runAfterPassFailed(Pass *pass, const llvm::Any &ir) {}
+  /// pointer to the pass that was being executed, as well as the current
+  /// operation being operated on. Note that the operation may be in an invalid
+  /// state.
+  virtual void runAfterPassFailed(Pass *pass, Operation *op) {}
 
   /// A callback to run before an analysis is computed. This function takes the
-  /// name of the analysis to be computed, its AnalysisID, as well as an
-  /// llvm::Any holding a pointer to the IR unit being analyzed on.
+  /// name of the analysis to be computed, its AnalysisID, as well as the
+  /// current operation being analyzed.
   virtual void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
-                                 const llvm::Any &ir) {}
+                                 Operation *op) {}
 
   /// A callback to run before an analysis is computed. This function takes the
-  /// name of the analysis that was computed, its AnalysisID, as well as an
-  /// llvm::Any holding a pointer to the IR unit that was analyzed.
+  /// name of the analysis that was computed, its AnalysisID, as well as the
+  /// current operation being analyzed.
   virtual void runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
-                                const llvm::Any &ir) {}
+                                Operation *op) {}
 };
 
 /// This class holds a collection of PassInstrumentation objects, and invokes
@@ -77,54 +77,25 @@ public:
   ~PassInstrumentor();
 
   /// See PassInstrumentation::runBeforePass for details.
-  template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT ir) {
-    runBeforePass(pass, llvm::Any(ir));
-  }
+  void runBeforePass(Pass *pass, Operation *op);
 
   /// See PassInstrumentation::runAfterPass for details.
-  template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT ir) {
-    runAfterPass(pass, llvm::Any(ir));
-  }
+  void runAfterPass(Pass *pass, Operation *op);
 
   /// See PassInstrumentation::runAfterPassFailed for details.
-  template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT ir) {
-    runAfterPassFailed(pass, llvm::Any(ir));
-  }
+  void runAfterPassFailed(Pass *pass, Operation *op);
 
   /// See PassInstrumentation::runBeforeAnalysis for details.
-  template <typename IRUnitT>
-  void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
-    runBeforeAnalysis(name, id, llvm::Any(ir));
-  }
+  void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, Operation *op);
 
   /// See PassInstrumentation::runAfterAnalysis for details.
-  template <typename IRUnitT>
-  void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
-    runAfterAnalysis(name, id, llvm::Any(ir));
-  }
+  void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, Operation *op);
 
   /// Add the given instrumentation to the collection. This takes ownership over
   /// the given pointer.
   void addInstrumentation(PassInstrumentation *pi);
 
 private:
-  /// See PassInstrumentation::runBeforePass for details.
-  void runBeforePass(Pass *pass, const llvm::Any &ir);
-
-  /// See PassInstrumentation::runAfterPass for details.
-  void runAfterPass(Pass *pass, const llvm::Any &ir);
-
-  /// See PassInstrumentation::runAfterPassFailed for details.
-  void runAfterPassFailed(Pass *pass, const llvm::Any &ir);
-
-  /// See PassInstrumentation::runBeforeAnalysis for details.
-  void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
-                         const llvm::Any &ir);
-
-  /// See PassInstrumentation::runAfterAnalysis for details.
-  void runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
-                        const llvm::Any &ir);
-
   std::unique_ptr<detail::PassInstrumentorImpl> impl;
 };
 
index 2de4b05..bc66197 100644 (file)
@@ -27,8 +27,8 @@ using namespace mlir::detail;
 namespace {
 class IRPrinterInstrumentation : public PassInstrumentation {
 public:
-  /// A filter function to decide if the given ir should be printed. Returns
-  /// true if the ir should be printed, false otherwise.
+  /// A filter function to decide if the given pass should be printed. Returns
+  /// true if the pass should be printed, false otherwise.
   using ShouldPrintFn = std::function<bool(Pass *)>;
 
   IRPrinterInstrumentation(ShouldPrintFn &&shouldPrintBeforePass,
@@ -43,9 +43,9 @@ public:
 
 private:
   /// Instrumentation hooks.
-  void runBeforePass(Pass *pass, const llvm::Any &ir) override;
-  void runAfterPass(Pass *pass, const llvm::Any &ir) override;
-  void runAfterPassFailed(Pass *pass, const llvm::Any &ir) override;
+  void runBeforePass(Pass *pass, Operation *op) override;
+  void runAfterPass(Pass *pass, Operation *op) override;
+  void runAfterPassFailed(Pass *pass, Operation *op) override;
 
   /// Filter functions for before and after pass execution.
   ShouldPrintFn shouldPrintBeforePass, shouldPrintAfterPass;
@@ -63,12 +63,10 @@ static bool isHiddenPass(Pass *pass) {
   return isAdaptorPass(pass) || isVerifierPass(pass);
 }
 
-static void printIR(const llvm::Any &ir, bool printModuleScope,
-                    raw_ostream &out) {
+static void printIR(Operation *op, bool printModuleScope, raw_ostream &out) {
   // Check for printing at module scope.
-  if (printModuleScope && llvm::any_isa<FuncOp>(ir)) {
-    FuncOp function = llvm::any_cast<FuncOp>(ir);
-
+  auto function = dyn_cast<FuncOp>(op);
+  if (printModuleScope && function) {
     // Print the function name and a newline before the Module.
     out << " (function: " << function.getName() << ")\n";
     function.getParentOfType<ModuleOp>().print(out);
@@ -79,45 +77,44 @@ static void printIR(const llvm::Any &ir, bool printModuleScope,
   out << "\n";
 
   // Print the given function.
-  if (llvm::any_isa<FuncOp>(ir)) {
-    llvm::any_cast<FuncOp>(ir).print(out);
+  if (function) {
+    function.print(out);
     return;
   }
 
   // Print the given module.
-  assert(llvm::any_isa<ModuleOp>(ir) && "unexpected IR unit");
-  llvm::any_cast<ModuleOp>(ir).print(out);
+  assert(isa<ModuleOp>(op) && "unexpected IR unit");
+  cast<ModuleOp>(op).print(out);
 }
 
 /// Instrumentation hooks.
-void IRPrinterInstrumentation::runBeforePass(Pass *pass, const llvm::Any &ir) {
-  // Skip adaptor passes and passes that the user filtered out.
+void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
+  // Skip hidden passes and passes that the user filtered out.
   if (!shouldPrintBeforePass || isHiddenPass(pass) ||
       !shouldPrintBeforePass(pass))
     return;
   out << formatv("*** IR Dump Before {0} ***", pass->getName());
-  printIR(ir, printModuleScope, out);
+  printIR(op, printModuleScope, out);
   out << "\n\n";
 }
 
-void IRPrinterInstrumentation::runAfterPass(Pass *pass, const llvm::Any &ir) {
-  // Skip adaptor passes and passes that the user filtered out.
+void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
+  // Skip hidden passes and passes that the user filtered out.
   if (!shouldPrintAfterPass || isHiddenPass(pass) ||
       !shouldPrintAfterPass(pass))
     return;
   out << formatv("*** IR Dump After {0} ***", pass->getName());
-  printIR(ir, printModuleScope, out);
+  printIR(op, printModuleScope, out);
   out << "\n\n";
 }
 
-void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass,
-                                                  const llvm::Any &ir) {
+void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
   // Skip adaptor passes and passes that the user filtered out.
   if (!shouldPrintAfterPass || isAdaptorPass(pass) ||
       !shouldPrintAfterPass(pass))
     return;
   out << formatv("*** IR Dump After {0} Failed ***", pass->getName());
-  printIR(ir, printModuleScope, out);
+  printIR(op, printModuleScope, out);
   out << "\n\n";
 }
 
index 35d9663..ba3b474 100644 (file)
@@ -393,40 +393,40 @@ PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
 PassInstrumentor::~PassInstrumentor() {}
 
 /// See PassInstrumentation::runBeforePass for details.
-void PassInstrumentor::runBeforePass(Pass *pass, const llvm::Any &ir) {
+void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) {
   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
   for (auto &instr : impl->instrumentations)
-    instr->runBeforePass(pass, ir);
+    instr->runBeforePass(pass, op);
 }
 
 /// See PassInstrumentation::runAfterPass for details.
-void PassInstrumentor::runAfterPass(Pass *pass, const llvm::Any &ir) {
+void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) {
   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
   for (auto &instr : llvm::reverse(impl->instrumentations))
-    instr->runAfterPass(pass, ir);
+    instr->runAfterPass(pass, op);
 }
 
 /// See PassInstrumentation::runAfterPassFailed for details.
-void PassInstrumentor::runAfterPassFailed(Pass *pass, const llvm::Any &ir) {
+void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
   for (auto &instr : llvm::reverse(impl->instrumentations))
-    instr->runAfterPassFailed(pass, ir);
+    instr->runAfterPassFailed(pass, op);
 }
 
 /// See PassInstrumentation::runBeforeAnalysis for details.
 void PassInstrumentor::runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
-                                         const llvm::Any &ir) {
+                                         Operation *op) {
   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
   for (auto &instr : impl->instrumentations)
-    instr->runBeforeAnalysis(name, id, ir);
+    instr->runBeforeAnalysis(name, id, op);
 }
 
 /// See PassInstrumentation::runAfterAnalysis for details.
 void PassInstrumentor::runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
-                                        const llvm::Any &ir) {
+                                        Operation *op) {
   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
   for (auto &instr : llvm::reverse(impl->instrumentations))
-    instr->runAfterAnalysis(name, id, ir);
+    instr->runAfterAnalysis(name, id, op);
 }
 
 /// Add the given instrumentation to the collection. This takes ownership over
index b4f3756..91b838c 100644 (file)
@@ -154,19 +154,16 @@ struct PassTiming : public PassInstrumentation {
   ~PassTiming() { print(); }
 
   /// Setup the instrumentation hooks.
-  void runBeforePass(Pass *pass, const llvm::Any &) override {
-    startPassTimer(pass);
-  }
-  void runAfterPass(Pass *pass, const llvm::Any &) override;
-  void runAfterPassFailed(Pass *pass, const llvm::Any &ir) override {
-    runAfterPass(pass, ir);
+  void runBeforePass(Pass *pass, Operation *) override { startPassTimer(pass); }
+  void runAfterPass(Pass *pass, Operation *) override;
+  void runAfterPassFailed(Pass *pass, Operation *op) override {
+    runAfterPass(pass, op);
   }
   void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
-                         const llvm::Any &) override {
+                         Operation *) override {
     startAnalysisTimer(name, id);
   }
-  void runAfterAnalysis(llvm::StringRef, AnalysisID *,
-                        const llvm::Any &) override;
+  void runAfterAnalysis(llvm::StringRef, AnalysisID *, Operation *) override;
 
   /// Print and clear the timing results.
   void print();
@@ -243,7 +240,7 @@ void PassTiming::startAnalysisTimer(llvm::StringRef name, AnalysisID *id) {
 }
 
 /// Stop a pass timer.
-void PassTiming::runAfterPass(Pass *pass, const llvm::Any &) {
+void PassTiming::runAfterPass(Pass *pass, Operation *) {
   auto tid = llvm::get_threadid();
   auto &activeTimers = activeThreadTimers[tid];
   assert(!activeTimers.empty() && "expected active timer");
@@ -277,8 +274,7 @@ void PassTiming::runAfterPass(Pass *pass, const llvm::Any &) {
 }
 
 /// Stop a timer.
-void PassTiming::runAfterAnalysis(llvm::StringRef, AnalysisID *,
-                                  const llvm::Any &) {
+void PassTiming::runAfterAnalysis(llvm::StringRef, AnalysisID *, Operation *) {
   auto &activeTimers = activeThreadTimers[llvm::get_threadid()];
   assert(!activeTimers.empty() && "expected active timer");
   Timer *timer = activeTimers.pop_back_val();