Generalize the pass hierarchy by adding a general OpPass<PassT, OpT>.
authorRiver Riddle <riverriddle@google.com>
Fri, 30 Aug 2019 20:16:13 +0000 (13:16 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 30 Aug 2019 20:16:37 +0000 (13:16 -0700)
This pass class generalizes the current functionality between FunctionPass and ModulePass, and allows for operating on any operation type. The pass manager currently only supports OpPasses operating on FuncOp and ModuleOp, but this restriction will be relaxed in follow-up changes. A utility class OpPassBase<OpT> allows for generically referring to operation specific passes: e.g. FunctionPassBase == OpPassBase<FuncOp>.

PiperOrigin-RevId: 266442239

19 files changed:
mlir/examples/Linalg/Linalg1/include/linalg1/Passes.h
mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h
mlir/include/mlir/Analysis/Passes.h
mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h
mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h
mlir/include/mlir/Dialect/FxpMathOps/Passes.h
mlir/include/mlir/Dialect/GPU/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/QuantOps/Passes.h
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassManager.h
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/ViewRegionGraph.h
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h

index d87fa3b..0347e18 100644 (file)
@@ -27,7 +27,9 @@
 #include <limits>
 
 namespace mlir {
-class ModulePassBase;
+class ModuleOp;
+template <typename T> class OpPassBase;
+using ModulePassBase = OpPassBase<ModuleOp>;
 } // namespace mlir
 
 namespace linalg {
index 123d6af..849d65a 100644 (file)
@@ -27,9 +27,11 @@ namespace mlir {
 class AffineForOp;
 class AffineMap;
 class FuncOp;
-class FunctionPassBase;
 class Operation;
 class Value;
+
+template <typename T> class OpPassBase;
+using FunctionPassBase = OpPassBase<FuncOp>;
 } // namespace mlir
 
 namespace linalg {
index 9eafcd3..8c947e6 100644 (file)
@@ -27,7 +27,9 @@
 
 namespace mlir {
 
-class FunctionPassBase;
+class FuncOp;
+template <typename T> class OpPassBase;
+using FunctionPassBase = OpPassBase<FuncOp>;
 
 /// Creates a pass to check memref accesses in a Function.
 FunctionPassBase *createMemRefBoundCheckPass();
index 78e4356..e6bf621 100644 (file)
 
 namespace mlir {
 class FuncOp;
-class FunctionPassBase;
 struct LogicalResult;
 class MLIRContext;
+template <typename T> class OpPassBase;
+using FunctionPassBase = OpPassBase<FuncOp>;
 class RewritePattern;
 
 // Owning list of rewriting patterns.
index b8b7a1e..8d5c501 100644 (file)
@@ -27,7 +27,7 @@ namespace mlir {
 
 class FuncOp;
 class Location;
-class ModulePassBase;
+class ModuleOp;
 class OpBuilder;
 class Value;
 
@@ -35,6 +35,9 @@ namespace LLVM {
 class LLVMDialect;
 }
 
+template <typename T> class OpPassBase;
+using ModulePassBase = OpPassBase<ModuleOp>;
+
 using OwnedCubin = std::unique_ptr<std::vector<char>>;
 using CubinGenerator = std::function<OwnedCubin(const std::string &, FuncOp &)>;
 
index 35f2314..01e50ba 100644 (file)
 
 namespace mlir {
 class LLVMTypeConverter;
-class ModulePassBase;
 class OwningRewritePatternList;
 
+class ModuleOp;
+template <typename OpT> class OpPassBase;
+using ModulePassBase = OpPassBase<ModuleOp>;
+
 /// Collect a set of patterns to convert from the GPU dialect to NVVM.
 void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns);
index 3d32c36..9ef21ea 100644 (file)
@@ -20,7 +20,9 @@
 #include <memory>
 
 namespace mlir {
-class FunctionPassBase;
+class FuncOp;
+template <typename T> class OpPassBase;
+using FunctionPassBase = OpPassBase<FuncOp>;
 
 /// Create a pass that converts loop nests into GPU kernels.  It considers
 /// top-level affine.for and linalg.for operations as roots of loop nests and
index d2f416b..10aa8ff 100644 (file)
@@ -33,7 +33,8 @@ class LLVMTypeConverter;
 struct LogicalResult;
 class MLIRContext;
 class ModuleOp;
-class ModulePassBase;
+template <typename T> class OpPassBase;
+using ModulePassBase = OpPassBase<ModuleOp>;
 class RewritePattern;
 class Type;
 
index 7334c67..c781858 100644 (file)
 
 namespace mlir {
 class LLVMTypeConverter;
-class ModulePassBase;
+class ModuleOp;
 class OwningRewritePatternList;
 
+template <typename T> class OpPassBase;
+using ModulePassBase = OpPassBase<ModuleOp>;
+
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             OwningRewritePatternList &patterns);
index 74c634a..f4099ab 100644 (file)
@@ -23,7 +23,9 @@
 #define MLIR_DIALECT_FXPMATHOPS_PASSES_H
 
 namespace mlir {
-class FunctionPassBase;
+class FuncOp;
+template <typename T> class OpPassBase;
+using FunctionPassBase = OpPassBase<FuncOp>;
 
 namespace fxpmath {
 
index d562b58..14a9f01 100644 (file)
@@ -26,7 +26,9 @@
 
 namespace mlir {
 
-class ModulePassBase;
+class ModuleOp;
+template <typename T> class OpPassBase;
+using ModulePassBase = OpPassBase<ModuleOp>;
 
 std::unique_ptr<ModulePassBase> createGpuKernelOutliningPass();
 
index e17439f..118e278 100644 (file)
 #include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
-class FunctionPassBase;
-class ModulePassBase;
+class FuncOp;
+class ModuleOp;
+template <typename T> class OpPassBase;
+using FunctionPassBase = OpPassBase<FuncOp>;
+using ModulePassBase = OpPassBase<ModuleOp>;
 
 namespace linalg {
 std::unique_ptr<FunctionPassBase>
index 1d43f70..5e5fd70 100644 (file)
@@ -28,7 +28,9 @@
 #include <memory>
 
 namespace mlir {
-class FunctionPassBase;
+class FuncOp;
+template <typename T> class OpPassBase;
+using FunctionPassBase = OpPassBase<FuncOp>;
 
 namespace quant {
 
index 360eaaf..79ede0e 100644 (file)
 #include "llvm/ADT/PointerIntPair.h"
 
 namespace mlir {
-/// The abstract base pass class. This class contains information describing the
-/// derived pass object, e.g its kind and abstract PassInfo.
-class Pass {
-public:
-  enum class Kind { FunctionPass, ModulePass };
-
-  virtual ~Pass() = default;
-
-  /// Returns the unique identifier that corresponds to this pass.
-  const PassID *getPassID() const { return passIDAndKind.getPointer(); }
-
-  /// Returns the pass info for the specified pass class or null if unknown.
-  static const PassInfo *lookupPassInfo(const PassID *passID);
-  template <typename PassT> static const PassInfo *lookupPassInfo() {
-    return lookupPassInfo(PassID::getID<PassT>());
-  }
-
-  /// Returns the pass info for this pass.
-  const PassInfo *lookupPassInfo() const { return lookupPassInfo(getPassID()); }
-
-  /// Return the kind of this pass.
-  Kind getKind() const { return passIDAndKind.getInt(); }
-
-  /// Returns the derived pass name.
-  virtual StringRef getName() = 0;
-
-protected:
-  Pass(const PassID *passID, Kind kind) : passIDAndKind(passID, kind) {}
-
-private:
-  /// Out of line virtual method to ensure vtables and metadata are emitted to a
-  /// single .o file.
-  virtual void anchor();
-
-  /// Represents a unique identifier for the pass and its kind.
-  llvm::PointerIntPair<const PassID *, 1, Kind> passIDAndKind;
-};
-
 namespace detail {
 class FunctionPassExecutor;
 class ModulePassExecutor;
 
 /// The state for a single execution of a pass. This provides a unified
 /// interface for accessing and initializing necessary state for pass execution.
-template <typename IRUnitT> struct PassExecutionState {
-  PassExecutionState(IRUnitT ir, AnalysisManager analysisManager)
+struct PassExecutionState {
+  PassExecutionState(Operation *ir, AnalysisManager analysisManager)
       : irAndPassFailed(ir, false), analysisManager(analysisManager) {}
 
-  /// The current IR unit being transformed and a bool for if the pass signaled
-  /// a failure.
-  llvm::PointerIntPair<IRUnitT, 1, bool> irAndPassFailed;
+  /// The current operation being transformed and a bool for if the pass
+  /// signaled a failure.
+  llvm::PointerIntPair<Operation *, 1, bool> irAndPassFailed;
 
-  /// The analysis manager for the IR unit.
+  /// The analysis manager for the operation.
   AnalysisManager analysisManager;
 
   /// The set of preserved analyses for the current execution.
@@ -85,81 +47,52 @@ template <typename IRUnitT> struct PassExecutionState {
 };
 } // namespace detail
 
-/// Pass to transform a specific function within a module. Derived passes should
-/// not inherit from this class directly, and instead should use the CRTP
-/// FunctionPass class.
-class FunctionPassBase : public Pass {
-  using PassStateT = detail::PassExecutionState<FuncOp>;
-
+/// The abstract base pass class. This class contains information describing the
+/// derived pass object, e.g its kind and abstract PassInfo.
+class Pass {
 public:
-  static bool classof(const Pass *pass) {
-    return pass->getKind() == Kind::FunctionPass;
-  }
+  virtual ~Pass() = default;
 
-protected:
-  explicit FunctionPassBase(const PassID *id) : Pass(id, Kind::FunctionPass) {}
+  /// Returns the unique identifier that corresponds to this pass.
+  const PassID *getPassID() const { return passID; }
 
-  /// The polymorphic API that runs the pass over the currently held function.
-  virtual void runOnFunction() = 0;
+  /// Returns the pass info for the specified pass class or null if unknown.
+  static const PassInfo *lookupPassInfo(const PassID *passID);
+  template <typename PassT> static const PassInfo *lookupPassInfo() {
+    return lookupPassInfo(PassID::getID<PassT>());
+  }
 
-  /// A clone method to create a copy of this pass.
-  virtual std::unique_ptr<FunctionPassBase> clone() const = 0;
+  /// Returns the pass info for this pass.
+  const PassInfo *lookupPassInfo() const { return lookupPassInfo(getPassID()); }
 
-  /// Return the current function being transformed.
-  FuncOp getFunction() { return getPassState().irAndPassFailed.getPointer(); }
+  /// Returns the derived pass name.
+  virtual StringRef getName() = 0;
 
-  /// Return the MLIR context for the current function being transformed.
-  MLIRContext &getContext() { return *getFunction().getContext(); }
+  /// Returns the name of the operation that this pass operates on.
+  StringRef getOpName() const { return opName; }
+
+protected:
+  Pass(const PassID *passID, StringRef opName)
+      : passID(passID), opName(opName) {}
 
   /// Returns the current pass state.
-  PassStateT &getPassState() {
+  detail::PassExecutionState &getPassState() {
     assert(passState && "pass state was never initialized");
     return *passState;
   }
 
-  /// Returns the current analysis manager.
-  AnalysisManager getAnalysisManager() {
-    return getPassState().analysisManager;
-  }
-
-private:
-  /// Forwarding function to execute this pass.
-  LLVM_NODISCARD
-  LogicalResult run(FuncOp fn, AnalysisManager am);
-
-  /// The current execution state for the pass.
-  llvm::Optional<PassStateT> passState;
-
-  /// Allow access to 'run'.
-  friend detail::FunctionPassExecutor;
-};
-
-/// Pass to transform a module. Derived passes should not inherit from this
-/// class directly, and instead should use the CRTP ModulePass class.
-class ModulePassBase : public Pass {
-  using PassStateT = detail::PassExecutionState<ModuleOp>;
-
-public:
-  static bool classof(const Pass *pass) {
-    return pass->getKind() == Kind::ModulePass;
-  }
-
-protected:
-  explicit ModulePassBase(const PassID *id) : Pass(id, Kind::ModulePass) {}
-
-  /// The polymorphic API that runs the pass over the currently held module.
-  virtual void runOnModule() = 0;
+  /// Return the MLIR context for the current function being transformed.
+  MLIRContext &getContext() { return *getOperation()->getContext(); }
 
-  /// Return the current module being transformed.
-  ModuleOp getModule() { return getPassState().irAndPassFailed.getPointer(); }
+  /// The polymorphic API that runs the pass over the currently held operation.
+  virtual void runOnOperation() = 0;
 
-  /// Return the MLIR context for the current module being transformed.
-  MLIRContext &getContext() { return *getModule().getContext(); }
+  /// A clone method to create a copy of this pass.
+  virtual std::unique_ptr<Pass> clone() const = 0;
 
-  /// Returns the current pass state.
-  PassStateT &getPassState() {
-    assert(passState && "pass state was never initialized");
-    return *passState;
+  /// Return the current operation being transformed.
+  Operation *getOperation() {
+    return getPassState().irAndPassFailed.getPointer();
   }
 
   /// Returns the current analysis manager.
@@ -168,14 +101,25 @@ protected:
   }
 
 private:
-  /// Forwarding function to execute this pass.
+  /// Forwarding function to execute this pass on the given operation.
   LLVM_NODISCARD
-  LogicalResult run(ModuleOp module, AnalysisManager am);
+  LogicalResult run(Operation *op, AnalysisManager am);
+
+  /// Out of line virtual method to ensure vtables and metadata are emitted to a
+  /// single .o file.
+  virtual void anchor();
+
+  /// Represents a unique identifier for the pass.
+  const PassID *passID;
+
+  /// The name of the operation that this pass operates on.
+  StringRef opName;
 
   /// The current execution state for the pass.
-  llvm::Optional<PassStateT> passState;
+  llvm::Optional<detail::PassExecutionState> passState;
 
-  /// Allow access to 'run'.
+  /// Allow access to 'clone' and 'run'.
+  friend detail::FunctionPassExecutor;
   friend detail::ModulePassExecutor;
 };
 
@@ -185,7 +129,7 @@ private:
 namespace detail {
 /// The opaque CRTP model of a pass. This class provides utilities for derived
 /// pass execution and handles all of the necessary polymorphic API.
-template <typename IRUnitT, typename PassT, typename BasePassT>
+template <typename PassT, typename BasePassT>
 class PassModel : public BasePassT {
 public:
   /// Support isa/dyn_cast functionality for the derived pass class.
@@ -194,7 +138,7 @@ public:
   }
 
 protected:
-  PassModel() : BasePassT(PassID::getID<PassT>()) {}
+  PassModel(StringRef opName) : BasePassT(PassID::getID<PassT>(), opName) {}
 
   /// Signal that some invariant was broken when running. The IR is allowed to
   /// be in an invalid state.
@@ -234,9 +178,75 @@ protected:
       name.consume_front("(anonymous namespace)::");
     return name;
   }
+
+  /// A clone method to create a copy of this pass.
+  std::unique_ptr<Pass> clone() const override {
+    return std::make_unique<PassT>(*static_cast<const PassT *>(this));
+  }
+
+  /// Returns the analysis for the parent operation if it exists.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>>
+  getCachedParentAnalysis(Operation *parent) {
+    return this->getAnalysisManager()
+        .template getCachedParentAnalysis<AnalysisT>(parent);
+  }
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis() {
+    return this->getAnalysisManager()
+        .template getCachedParentAnalysis<AnalysisT>(
+            this->getOperation()->getParentOp());
+  }
+
+  /// Returns the analysis for the given child operation if it exists.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>>
+  getCachedChildAnalysis(Operation *child) {
+    return this->getAnalysisManager()
+        .template getCachedChildAnalysis<AnalysisT>(child);
+  }
+
+  /// Returns the analysis for the given child operation, or creates it if it
+  /// doesn't exist.
+  template <typename AnalysisT> AnalysisT &getChildAnalysis(Operation *child) {
+    return this->getAnalysisManager().template getChildAnalysis<AnalysisT>(
+        child);
+  }
 };
 } // end namespace detail
 
+/// Utility base class for OpPass below to denote an opaque pass operating on a
+/// specific operation type.
+template <typename OpT> class OpPassBase : public Pass {
+public:
+  using Pass::Pass;
+
+  /// Support isa/dyn_cast functionality.
+  static bool classof(const Pass *pass) {
+    return pass->getOpName() == OpT::getOperationName();
+  }
+};
+
+/// Pass to transform an operation of a specific type.
+///
+/// Operation passes must not:
+///   - read or modify any other operations within the parent region, as
+///     other threads may be manipulating them concurrently.
+///   - modify any state within the parent operation, this includes adding
+///     additional operations.
+///
+/// Derived function passes are expected to provide the following:
+///   - A 'void runOnOperation()' method.
+template <typename OpT, typename PassT>
+class OpPass : public detail::PassModel<PassT, OpPassBase<OpT>> {
+protected:
+  OpPass()
+      : detail::PassModel<PassT, OpPassBase<OpT>>(OpT::getOperationName()) {}
+
+  /// Return the current operation being transformed.
+  OpT getOperation() { return cast<OpT>(Pass::getOperation()); }
+};
+
 /// A model for providing function pass specific utilities.
 ///
 /// Function passes must not:
@@ -247,41 +257,39 @@ protected:
 ///
 /// Derived function passes are expected to provide the following:
 ///   - A 'void runOnFunction()' method.
-template <typename T>
-struct FunctionPass : public detail::PassModel<FuncOp, T, FunctionPassBase> {
-  /// Returns the analysis for the parent module if it exists.
-  template <typename AnalysisT>
-  llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedModuleAnalysis() {
-    return this->getAnalysisManager()
-        .template getCachedParentAnalysis<AnalysisT>(
-            this->getFunction().getParentOp());
-  }
+template <typename T> struct FunctionPass : public OpPass<FuncOp, T> {
+  /// The polymorphic API that runs the pass over the currently held function.
+  virtual void runOnFunction() = 0;
 
-  /// A clone method to create a copy of this pass.
-  std::unique_ptr<FunctionPassBase> clone() const override {
-    return std::make_unique<T>(*static_cast<const T *>(this));
+  /// The polymorphic API that runs the pass over the currently held operation.
+  void runOnOperation() final {
+    if (!getFunction().isExternal())
+      runOnFunction();
   }
+
+  /// Return the current module being transformed.
+  FuncOp getFunction() { return this->getOperation(); }
 };
 
 /// A model for providing module pass specific utilities.
 ///
 /// Derived module passes are expected to provide the following:
 ///   - A 'void runOnModule()' method.
-template <typename T>
-struct ModulePass : public detail::PassModel<ModuleOp, T, ModulePassBase> {
-  /// Returns the analysis for a child function.
-  template <typename AnalysisT> AnalysisT &getFunctionAnalysis(FuncOp f) {
-    return this->getAnalysisManager().template getChildAnalysis<AnalysisT>(f);
-  }
+template <typename T> struct ModulePass : public OpPass<ModuleOp, T> {
+  /// The polymorphic API that runs the pass over the currently held module.
+  virtual void runOnModule() = 0;
 
-  /// Returns an existing analysis for a child function if it exists.
-  template <typename AnalysisT>
-  llvm::Optional<std::reference_wrapper<AnalysisT>>
-  getCachedFunctionAnalysis(FuncOp f) {
-    return this->getAnalysisManager()
-        .template getCachedChildAnalysis<AnalysisT>(f);
-  }
+  /// The polymorphic API that runs the pass over the currently held operation.
+  void runOnOperation() final { runOnModule(); }
+
+  /// Return the current module being transformed.
+  ModuleOp getModule() { return this->getOperation(); }
 };
+
+/// Using directives defining legacy base classes.
+// TODO(riverriddle) These should be removed in favor of OpPassBase<T>.
+using FunctionPassBase = OpPassBase<FuncOp>;
+using ModulePassBase = OpPassBase<ModuleOp>;
 } // end namespace mlir
 
 #endif // MLIR_PASS_PASS_H
index b01445e..888d903 100644 (file)
@@ -26,9 +26,12 @@ class Any;
 } // end namespace llvm
 
 namespace mlir {
-class FunctionPassBase;
+class FuncOp;
 class ModuleOp;
-class ModulePassBase;
+template <typename OpT> class OpPassBase;
+using FunctionPassBase = OpPassBase<FuncOp>;
+using ModulePassBase = OpPassBase<ModuleOp>;
+
 class Pass;
 class PassInstrumentation;
 class PassInstrumentor;
index 693c7b0..dc3a213 100644 (file)
 namespace mlir {
 
 class AffineForOp;
-class FunctionPassBase;
-class ModulePassBase;
+class FuncOp;
+class ModuleOp;
+template <typename T> class OpPassBase;
+using FunctionPassBase = OpPassBase<FuncOp>;
+using ModulePassBase = OpPassBase<ModuleOp>;
 
 /// Creates a constant folding pass. Note that this pass solely provides simple
 /// top-down constant folding functionality; it is intended to be used for
index 61da9f1..f54d356 100644 (file)
@@ -27,7 +27,9 @@
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
-class FunctionPassBase;
+class FuncOp;
+template <typename T> class OpPassBase;
+using FunctionPassBase = OpPassBase<FuncOp>;
 class Region;
 
 /// Displays the CFG in a window. This is for use from the debugger and
index 0892aa0..e208e20 100644 (file)
@@ -42,17 +42,16 @@ using namespace mlir::detail;
 void Pass::anchor() {}
 
 /// Forwarding function to execute this pass.
-LogicalResult FunctionPassBase::run(FuncOp fn, AnalysisManager am) {
-  // Initialize the pass state.
-  passState.emplace(fn, am);
+LogicalResult Pass::run(Operation *op, AnalysisManager am) {
+  passState.emplace(op, am);
 
   // Instrument before the pass has run.
   auto pi = am.getPassInstrumentor();
   if (pi)
-    pi->runBeforePass(this, fn);
+    pi->runBeforePass(this, op);
 
-  // Invoke the virtual runOnFunction function.
-  runOnFunction();
+  // Invoke the virtual runOnOperation method.
+  runOnOperation();
 
   // Invalidate any non preserved analyses.
   am.invalidate(passState->preservedAnalyses);
@@ -61,38 +60,9 @@ LogicalResult FunctionPassBase::run(FuncOp fn, AnalysisManager am) {
   bool passFailed = passState->irAndPassFailed.getInt();
   if (pi) {
     if (passFailed)
-      pi->runAfterPassFailed(this, fn);
+      pi->runAfterPassFailed(this, op);
     else
-      pi->runAfterPass(this, fn);
-  }
-
-  // Return if the pass signaled a failure.
-  return failure(passFailed);
-}
-
-/// Forwarding function to execute this pass.
-LogicalResult ModulePassBase::run(ModuleOp module, AnalysisManager am) {
-  // Initialize the pass state.
-  passState.emplace(module, am);
-
-  // Instrument before the pass has run.
-  auto pi = am.getPassInstrumentor();
-  if (pi)
-    pi->runBeforePass(this, module);
-
-  // Invoke the virtual runOnModule function.
-  runOnModule();
-
-  // Invalidate any non preserved analyses.
-  am.invalidate(passState->preservedAnalyses);
-
-  // Instrument after the pass has run.
-  bool passFailed = passState->irAndPassFailed.getInt();
-  if (pi) {
-    if (passFailed)
-      pi->runAfterPassFailed(this, module);
-    else
-      pi->runAfterPass(this, module);
+      pi->runAfterPass(this, op);
   }
 
   // Return if the pass signaled a failure.
@@ -106,7 +76,7 @@ LogicalResult ModulePassBase::run(ModuleOp module, AnalysisManager am) {
 FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs)
     : PassExecutor(Kind::FunctionExecutor) {
   for (auto &pass : rhs.passes)
-    addPass(pass->clone());
+    addPass(cast<FunctionPassBase>(pass->clone()));
 }
 
 /// Run all of the passes in this manager over the current function.
@@ -265,14 +235,9 @@ void PassManager::disableMultithreading(bool disable) {
 /// Add an opaque pass pointer to the current manager. This takes ownership
 /// over the provided pass pointer.
 void PassManager::addPass(std::unique_ptr<Pass> pass) {
-  switch (pass->getKind()) {
-  case Pass::Kind::FunctionPass:
-    addPass(cast<FunctionPassBase>(std::move(pass)));
-    break;
-  case Pass::Kind::ModulePass:
-    addPass(cast<ModulePassBase>(std::move(pass)));
-    break;
-  }
+  if (isa<FunctionPassBase>(pass.get()))
+    return addPass(cast<FunctionPassBase>(std::move(pass)));
+  addPass(cast<ModulePassBase>(std::move(pass)));
 }
 
 /// Add a module pass to the current manager. This takes ownership over the
index aa60cfb..40e75f4 100644 (file)
@@ -66,7 +66,7 @@ public:
 
   /// Add a pass to the current executor. This takes ownership over the provided
   /// pass pointer.
-  void addPass(std::unique_ptr<FunctionPassBase> pass) {
+  void addPass(std::unique_ptr<Pass> pass) {
     passes.push_back(std::move(pass));
   }
 
@@ -78,7 +78,7 @@ public:
   }
 
 private:
-  std::vector<std::unique_ptr<FunctionPassBase>> passes;
+  std::vector<std::unique_ptr<Pass>> passes;
 };
 
 /// A pass executor that contains a list of passes over a module unit.