From 1dd9bf47393afba3adad066fde950eb9c608f2cc Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 30 Aug 2019 13:16:13 -0700 Subject: [PATCH] Generalize the pass hierarchy by adding a general OpPass. This pass class generalizes the current functionality between FunctionPass and ModulePass, and allows for operating on any operation type. The pass manager currently only supports OpPasses operating on FuncOp and ModuleOp, but this restriction will be relaxed in follow-up changes. A utility class OpPassBase allows for generically referring to operation specific passes: e.g. FunctionPassBase == OpPassBase. PiperOrigin-RevId: 266442239 --- .../Linalg/Linalg1/include/linalg1/Passes.h | 4 +- .../Linalg3/include/linalg3/Transforms.h | 4 +- mlir/include/mlir/Analysis/Passes.h | 4 +- .../ConvertControlFlowToCFG.h | 3 +- .../mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h | 5 +- .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h | 5 +- .../Conversion/LoopsToGPU/LoopsToGPUPass.h | 4 +- .../ConvertStandardToLLVMPass.h | 3 +- .../Conversion/VectorToLLVM/VectorToLLVM.h | 5 +- mlir/include/mlir/Dialect/FxpMathOps/Passes.h | 4 +- mlir/include/mlir/Dialect/GPU/Passes.h | 4 +- mlir/include/mlir/Dialect/Linalg/Passes.h | 7 +- mlir/include/mlir/Dialect/QuantOps/Passes.h | 4 +- mlir/include/mlir/Pass/Pass.h | 278 +++++++++--------- mlir/include/mlir/Pass/PassManager.h | 7 +- mlir/include/mlir/Transforms/Passes.h | 7 +- .../include/mlir/Transforms/ViewRegionGraph.h | 4 +- mlir/lib/Pass/Pass.cpp | 57 +--- mlir/lib/Pass/PassDetail.h | 4 +- 19 files changed, 211 insertions(+), 202 deletions(-) diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Passes.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Passes.h index d87fa3b1ca1f..0347e182a50b 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Passes.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Passes.h @@ -27,7 +27,9 @@ #include namespace mlir { -class ModulePassBase; +class ModuleOp; +template class OpPassBase; +using ModulePassBase = OpPassBase; } // namespace mlir namespace linalg { diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h index 123d6afba088..849d65a6b6f6 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h @@ -27,9 +27,11 @@ namespace mlir { class AffineForOp; class AffineMap; class FuncOp; -class FunctionPassBase; class Operation; class Value; + +template class OpPassBase; +using FunctionPassBase = OpPassBase; } // namespace mlir namespace linalg { diff --git a/mlir/include/mlir/Analysis/Passes.h b/mlir/include/mlir/Analysis/Passes.h index 9eafcd355762..8c947e6c2226 100644 --- a/mlir/include/mlir/Analysis/Passes.h +++ b/mlir/include/mlir/Analysis/Passes.h @@ -27,7 +27,9 @@ namespace mlir { -class FunctionPassBase; +class FuncOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; /// Creates a pass to check memref accesses in a Function. FunctionPassBase *createMemRefBoundCheckPass(); diff --git a/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h b/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h index 78e4356607fe..e6bf621cd7c1 100644 --- a/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h +++ b/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h @@ -23,9 +23,10 @@ namespace mlir { class FuncOp; -class FunctionPassBase; struct LogicalResult; class MLIRContext; +template class OpPassBase; +using FunctionPassBase = OpPassBase; class RewritePattern; // Owning list of rewriting patterns. diff --git a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h index b8b7a1e37ef1..8d5c50135995 100644 --- a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h +++ b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h @@ -27,7 +27,7 @@ namespace mlir { class FuncOp; class Location; -class ModulePassBase; +class ModuleOp; class OpBuilder; class Value; @@ -35,6 +35,9 @@ namespace LLVM { class LLVMDialect; } +template class OpPassBase; +using ModulePassBase = OpPassBase; + using OwnedCubin = std::unique_ptr>; using CubinGenerator = std::function; diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 35f231464f1a..01e50baa592f 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -21,9 +21,12 @@ namespace mlir { class LLVMTypeConverter; -class ModulePassBase; class OwningRewritePatternList; +class ModuleOp; +template class OpPassBase; +using ModulePassBase = OpPassBase; + /// Collect a set of patterns to convert from the GPU dialect to NVVM. void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h index 3d32c36c43c0..9ef21ea97b6e 100644 --- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h @@ -20,7 +20,9 @@ #include namespace mlir { -class FunctionPassBase; +class FuncOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; /// Create a pass that converts loop nests into GPU kernels. It considers /// top-level affine.for and linalg.for operations as roots of loop nests and diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index d2f416b35fe9..10aa8ff96287 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -33,7 +33,8 @@ class LLVMTypeConverter; struct LogicalResult; class MLIRContext; class ModuleOp; -class ModulePassBase; +template class OpPassBase; +using ModulePassBase = OpPassBase; class RewritePattern; class Type; diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h index 7334c67e0d3d..c781858a6727 100644 --- a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h @@ -19,9 +19,12 @@ namespace mlir { class LLVMTypeConverter; -class ModulePassBase; +class ModuleOp; class OwningRewritePatternList; +template class OpPassBase; +using ModulePassBase = OpPassBase; + /// Collect a set of patterns to convert from the Vector dialect to LLVM. void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); diff --git a/mlir/include/mlir/Dialect/FxpMathOps/Passes.h b/mlir/include/mlir/Dialect/FxpMathOps/Passes.h index 74c634a68895..f4099ab77548 100644 --- a/mlir/include/mlir/Dialect/FxpMathOps/Passes.h +++ b/mlir/include/mlir/Dialect/FxpMathOps/Passes.h @@ -23,7 +23,9 @@ #define MLIR_DIALECT_FXPMATHOPS_PASSES_H namespace mlir { -class FunctionPassBase; +class FuncOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; namespace fxpmath { diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h index d562b5835c7a..14a9f013c99c 100644 --- a/mlir/include/mlir/Dialect/GPU/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Passes.h @@ -26,7 +26,9 @@ namespace mlir { -class ModulePassBase; +class ModuleOp; +template class OpPassBase; +using ModulePassBase = OpPassBase; std::unique_ptr createGpuKernelOutliningPass(); diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index e17439f6eea9..118e278ef60d 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -26,8 +26,11 @@ #include "llvm/ADT/ArrayRef.h" namespace mlir { -class FunctionPassBase; -class ModulePassBase; +class FuncOp; +class ModuleOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; +using ModulePassBase = OpPassBase; namespace linalg { std::unique_ptr diff --git a/mlir/include/mlir/Dialect/QuantOps/Passes.h b/mlir/include/mlir/Dialect/QuantOps/Passes.h index 1d43f7087db1..5e5fd700f920 100644 --- a/mlir/include/mlir/Dialect/QuantOps/Passes.h +++ b/mlir/include/mlir/Dialect/QuantOps/Passes.h @@ -28,7 +28,9 @@ #include namespace mlir { -class FunctionPassBase; +class FuncOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; namespace quant { diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 360eaaff9b3d..79ede0e92e04 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -25,59 +25,21 @@ #include "llvm/ADT/PointerIntPair.h" namespace mlir { -/// The abstract base pass class. This class contains information describing the -/// derived pass object, e.g its kind and abstract PassInfo. -class Pass { -public: - enum class Kind { FunctionPass, ModulePass }; - - virtual ~Pass() = default; - - /// Returns the unique identifier that corresponds to this pass. - const PassID *getPassID() const { return passIDAndKind.getPointer(); } - - /// Returns the pass info for the specified pass class or null if unknown. - static const PassInfo *lookupPassInfo(const PassID *passID); - template static const PassInfo *lookupPassInfo() { - return lookupPassInfo(PassID::getID()); - } - - /// Returns the pass info for this pass. - const PassInfo *lookupPassInfo() const { return lookupPassInfo(getPassID()); } - - /// Return the kind of this pass. - Kind getKind() const { return passIDAndKind.getInt(); } - - /// Returns the derived pass name. - virtual StringRef getName() = 0; - -protected: - Pass(const PassID *passID, Kind kind) : passIDAndKind(passID, kind) {} - -private: - /// Out of line virtual method to ensure vtables and metadata are emitted to a - /// single .o file. - virtual void anchor(); - - /// Represents a unique identifier for the pass and its kind. - llvm::PointerIntPair passIDAndKind; -}; - namespace detail { class FunctionPassExecutor; class ModulePassExecutor; /// The state for a single execution of a pass. This provides a unified /// interface for accessing and initializing necessary state for pass execution. -template struct PassExecutionState { - PassExecutionState(IRUnitT ir, AnalysisManager analysisManager) +struct PassExecutionState { + PassExecutionState(Operation *ir, AnalysisManager analysisManager) : irAndPassFailed(ir, false), analysisManager(analysisManager) {} - /// The current IR unit being transformed and a bool for if the pass signaled - /// a failure. - llvm::PointerIntPair irAndPassFailed; + /// The current operation being transformed and a bool for if the pass + /// signaled a failure. + llvm::PointerIntPair irAndPassFailed; - /// The analysis manager for the IR unit. + /// The analysis manager for the operation. AnalysisManager analysisManager; /// The set of preserved analyses for the current execution. @@ -85,81 +47,52 @@ template struct PassExecutionState { }; } // namespace detail -/// Pass to transform a specific function within a module. Derived passes should -/// not inherit from this class directly, and instead should use the CRTP -/// FunctionPass class. -class FunctionPassBase : public Pass { - using PassStateT = detail::PassExecutionState; - +/// The abstract base pass class. This class contains information describing the +/// derived pass object, e.g its kind and abstract PassInfo. +class Pass { public: - static bool classof(const Pass *pass) { - return pass->getKind() == Kind::FunctionPass; - } + virtual ~Pass() = default; -protected: - explicit FunctionPassBase(const PassID *id) : Pass(id, Kind::FunctionPass) {} + /// Returns the unique identifier that corresponds to this pass. + const PassID *getPassID() const { return passID; } - /// The polymorphic API that runs the pass over the currently held function. - virtual void runOnFunction() = 0; + /// Returns the pass info for the specified pass class or null if unknown. + static const PassInfo *lookupPassInfo(const PassID *passID); + template static const PassInfo *lookupPassInfo() { + return lookupPassInfo(PassID::getID()); + } - /// A clone method to create a copy of this pass. - virtual std::unique_ptr clone() const = 0; + /// Returns the pass info for this pass. + const PassInfo *lookupPassInfo() const { return lookupPassInfo(getPassID()); } - /// Return the current function being transformed. - FuncOp getFunction() { return getPassState().irAndPassFailed.getPointer(); } + /// Returns the derived pass name. + virtual StringRef getName() = 0; - /// Return the MLIR context for the current function being transformed. - MLIRContext &getContext() { return *getFunction().getContext(); } + /// Returns the name of the operation that this pass operates on. + StringRef getOpName() const { return opName; } + +protected: + Pass(const PassID *passID, StringRef opName) + : passID(passID), opName(opName) {} /// Returns the current pass state. - PassStateT &getPassState() { + detail::PassExecutionState &getPassState() { assert(passState && "pass state was never initialized"); return *passState; } - /// Returns the current analysis manager. - AnalysisManager getAnalysisManager() { - return getPassState().analysisManager; - } - -private: - /// Forwarding function to execute this pass. - LLVM_NODISCARD - LogicalResult run(FuncOp fn, AnalysisManager am); - - /// The current execution state for the pass. - llvm::Optional passState; - - /// Allow access to 'run'. - friend detail::FunctionPassExecutor; -}; - -/// Pass to transform a module. Derived passes should not inherit from this -/// class directly, and instead should use the CRTP ModulePass class. -class ModulePassBase : public Pass { - using PassStateT = detail::PassExecutionState; - -public: - static bool classof(const Pass *pass) { - return pass->getKind() == Kind::ModulePass; - } - -protected: - explicit ModulePassBase(const PassID *id) : Pass(id, Kind::ModulePass) {} - - /// The polymorphic API that runs the pass over the currently held module. - virtual void runOnModule() = 0; + /// Return the MLIR context for the current function being transformed. + MLIRContext &getContext() { return *getOperation()->getContext(); } - /// Return the current module being transformed. - ModuleOp getModule() { return getPassState().irAndPassFailed.getPointer(); } + /// The polymorphic API that runs the pass over the currently held operation. + virtual void runOnOperation() = 0; - /// Return the MLIR context for the current module being transformed. - MLIRContext &getContext() { return *getModule().getContext(); } + /// A clone method to create a copy of this pass. + virtual std::unique_ptr clone() const = 0; - /// Returns the current pass state. - PassStateT &getPassState() { - assert(passState && "pass state was never initialized"); - return *passState; + /// Return the current operation being transformed. + Operation *getOperation() { + return getPassState().irAndPassFailed.getPointer(); } /// Returns the current analysis manager. @@ -168,14 +101,25 @@ protected: } private: - /// Forwarding function to execute this pass. + /// Forwarding function to execute this pass on the given operation. LLVM_NODISCARD - LogicalResult run(ModuleOp module, AnalysisManager am); + LogicalResult run(Operation *op, AnalysisManager am); + + /// Out of line virtual method to ensure vtables and metadata are emitted to a + /// single .o file. + virtual void anchor(); + + /// Represents a unique identifier for the pass. + const PassID *passID; + + /// The name of the operation that this pass operates on. + StringRef opName; /// The current execution state for the pass. - llvm::Optional passState; + llvm::Optional passState; - /// Allow access to 'run'. + /// Allow access to 'clone' and 'run'. + friend detail::FunctionPassExecutor; friend detail::ModulePassExecutor; }; @@ -185,7 +129,7 @@ private: namespace detail { /// The opaque CRTP model of a pass. This class provides utilities for derived /// pass execution and handles all of the necessary polymorphic API. -template +template class PassModel : public BasePassT { public: /// Support isa/dyn_cast functionality for the derived pass class. @@ -194,7 +138,7 @@ public: } protected: - PassModel() : BasePassT(PassID::getID()) {} + PassModel(StringRef opName) : BasePassT(PassID::getID(), opName) {} /// Signal that some invariant was broken when running. The IR is allowed to /// be in an invalid state. @@ -234,9 +178,75 @@ protected: name.consume_front("(anonymous namespace)::"); return name; } + + /// A clone method to create a copy of this pass. + std::unique_ptr clone() const override { + return std::make_unique(*static_cast(this)); + } + + /// Returns the analysis for the parent operation if it exists. + template + llvm::Optional> + getCachedParentAnalysis(Operation *parent) { + return this->getAnalysisManager() + .template getCachedParentAnalysis(parent); + } + template + llvm::Optional> getCachedParentAnalysis() { + return this->getAnalysisManager() + .template getCachedParentAnalysis( + this->getOperation()->getParentOp()); + } + + /// Returns the analysis for the given child operation if it exists. + template + llvm::Optional> + getCachedChildAnalysis(Operation *child) { + return this->getAnalysisManager() + .template getCachedChildAnalysis(child); + } + + /// Returns the analysis for the given child operation, or creates it if it + /// doesn't exist. + template AnalysisT &getChildAnalysis(Operation *child) { + return this->getAnalysisManager().template getChildAnalysis( + child); + } }; } // end namespace detail +/// Utility base class for OpPass below to denote an opaque pass operating on a +/// specific operation type. +template class OpPassBase : public Pass { +public: + using Pass::Pass; + + /// Support isa/dyn_cast functionality. + static bool classof(const Pass *pass) { + return pass->getOpName() == OpT::getOperationName(); + } +}; + +/// Pass to transform an operation of a specific type. +/// +/// Operation passes must not: +/// - read or modify any other operations within the parent region, as +/// other threads may be manipulating them concurrently. +/// - modify any state within the parent operation, this includes adding +/// additional operations. +/// +/// Derived function passes are expected to provide the following: +/// - A 'void runOnOperation()' method. +template +class OpPass : public detail::PassModel> { +protected: + OpPass() + : detail::PassModel>(OpT::getOperationName()) {} + + /// Return the current operation being transformed. + OpT getOperation() { return cast(Pass::getOperation()); } +}; + /// A model for providing function pass specific utilities. /// /// Function passes must not: @@ -247,41 +257,39 @@ protected: /// /// Derived function passes are expected to provide the following: /// - A 'void runOnFunction()' method. -template -struct FunctionPass : public detail::PassModel { - /// Returns the analysis for the parent module if it exists. - template - llvm::Optional> getCachedModuleAnalysis() { - return this->getAnalysisManager() - .template getCachedParentAnalysis( - this->getFunction().getParentOp()); - } +template struct FunctionPass : public OpPass { + /// The polymorphic API that runs the pass over the currently held function. + virtual void runOnFunction() = 0; - /// A clone method to create a copy of this pass. - std::unique_ptr clone() const override { - return std::make_unique(*static_cast(this)); + /// The polymorphic API that runs the pass over the currently held operation. + void runOnOperation() final { + if (!getFunction().isExternal()) + runOnFunction(); } + + /// Return the current module being transformed. + FuncOp getFunction() { return this->getOperation(); } }; /// A model for providing module pass specific utilities. /// /// Derived module passes are expected to provide the following: /// - A 'void runOnModule()' method. -template -struct ModulePass : public detail::PassModel { - /// Returns the analysis for a child function. - template AnalysisT &getFunctionAnalysis(FuncOp f) { - return this->getAnalysisManager().template getChildAnalysis(f); - } +template struct ModulePass : public OpPass { + /// The polymorphic API that runs the pass over the currently held module. + virtual void runOnModule() = 0; - /// Returns an existing analysis for a child function if it exists. - template - llvm::Optional> - getCachedFunctionAnalysis(FuncOp f) { - return this->getAnalysisManager() - .template getCachedChildAnalysis(f); - } + /// The polymorphic API that runs the pass over the currently held operation. + void runOnOperation() final { runOnModule(); } + + /// Return the current module being transformed. + ModuleOp getModule() { return this->getOperation(); } }; + +/// Using directives defining legacy base classes. +// TODO(riverriddle) These should be removed in favor of OpPassBase. +using FunctionPassBase = OpPassBase; +using ModulePassBase = OpPassBase; } // end namespace mlir #endif // MLIR_PASS_PASS_H diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index b01445eae4cb..888d903a2941 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -26,9 +26,12 @@ class Any; } // end namespace llvm namespace mlir { -class FunctionPassBase; +class FuncOp; class ModuleOp; -class ModulePassBase; +template class OpPassBase; +using FunctionPassBase = OpPassBase; +using ModulePassBase = OpPassBase; + class Pass; class PassInstrumentation; class PassInstrumentor; diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 693c7b0ae00b..dc3a213f09a0 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -30,8 +30,11 @@ namespace mlir { class AffineForOp; -class FunctionPassBase; -class ModulePassBase; +class FuncOp; +class ModuleOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; +using ModulePassBase = OpPassBase; /// Creates a constant folding pass. Note that this pass solely provides simple /// top-down constant folding functionality; it is intended to be used for diff --git a/mlir/include/mlir/Transforms/ViewRegionGraph.h b/mlir/include/mlir/Transforms/ViewRegionGraph.h index 61da9f11f19f..f54d35643eb0 100644 --- a/mlir/include/mlir/Transforms/ViewRegionGraph.h +++ b/mlir/include/mlir/Transforms/ViewRegionGraph.h @@ -27,7 +27,9 @@ #include "llvm/Support/raw_ostream.h" namespace mlir { -class FunctionPassBase; +class FuncOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; class Region; /// Displays the CFG in a window. This is for use from the debugger and diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 0892aa087e84..e208e2029d93 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -42,17 +42,16 @@ using namespace mlir::detail; void Pass::anchor() {} /// Forwarding function to execute this pass. -LogicalResult FunctionPassBase::run(FuncOp fn, AnalysisManager am) { - // Initialize the pass state. - passState.emplace(fn, am); +LogicalResult Pass::run(Operation *op, AnalysisManager am) { + passState.emplace(op, am); // Instrument before the pass has run. auto pi = am.getPassInstrumentor(); if (pi) - pi->runBeforePass(this, fn); + pi->runBeforePass(this, op); - // Invoke the virtual runOnFunction function. - runOnFunction(); + // Invoke the virtual runOnOperation method. + runOnOperation(); // Invalidate any non preserved analyses. am.invalidate(passState->preservedAnalyses); @@ -61,38 +60,9 @@ LogicalResult FunctionPassBase::run(FuncOp fn, AnalysisManager am) { bool passFailed = passState->irAndPassFailed.getInt(); if (pi) { if (passFailed) - pi->runAfterPassFailed(this, fn); + pi->runAfterPassFailed(this, op); else - pi->runAfterPass(this, fn); - } - - // Return if the pass signaled a failure. - return failure(passFailed); -} - -/// Forwarding function to execute this pass. -LogicalResult ModulePassBase::run(ModuleOp module, AnalysisManager am) { - // Initialize the pass state. - passState.emplace(module, am); - - // Instrument before the pass has run. - auto pi = am.getPassInstrumentor(); - if (pi) - pi->runBeforePass(this, module); - - // Invoke the virtual runOnModule function. - runOnModule(); - - // Invalidate any non preserved analyses. - am.invalidate(passState->preservedAnalyses); - - // Instrument after the pass has run. - bool passFailed = passState->irAndPassFailed.getInt(); - if (pi) { - if (passFailed) - pi->runAfterPassFailed(this, module); - else - pi->runAfterPass(this, module); + pi->runAfterPass(this, op); } // Return if the pass signaled a failure. @@ -106,7 +76,7 @@ LogicalResult ModulePassBase::run(ModuleOp module, AnalysisManager am) { FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs) : PassExecutor(Kind::FunctionExecutor) { for (auto &pass : rhs.passes) - addPass(pass->clone()); + addPass(cast(pass->clone())); } /// Run all of the passes in this manager over the current function. @@ -265,14 +235,9 @@ void PassManager::disableMultithreading(bool disable) { /// Add an opaque pass pointer to the current manager. This takes ownership /// over the provided pass pointer. void PassManager::addPass(std::unique_ptr pass) { - switch (pass->getKind()) { - case Pass::Kind::FunctionPass: - addPass(cast(std::move(pass))); - break; - case Pass::Kind::ModulePass: - addPass(cast(std::move(pass))); - break; - } + if (isa(pass.get())) + return addPass(cast(std::move(pass))); + addPass(cast(std::move(pass))); } /// Add a module pass to the current manager. This takes ownership over the diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index aa60cfb23ea7..40e75f4cf3ad 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -66,7 +66,7 @@ public: /// Add a pass to the current executor. This takes ownership over the provided /// pass pointer. - void addPass(std::unique_ptr pass) { + void addPass(std::unique_ptr pass) { passes.push_back(std::move(pass)); } @@ -78,7 +78,7 @@ public: } private: - std::vector> passes; + std::vector> passes; }; /// A pass executor that contains a list of passes over a module unit. -- 2.34.1