#include "llvm/ADT/PointerIntPair.h"
namespace mlir {
-/// The abstract base pass class. This class contains information describing the
-/// derived pass object, e.g its kind and abstract PassInfo.
-class Pass {
-public:
- enum class Kind { FunctionPass, ModulePass };
-
- virtual ~Pass() = default;
-
- /// Returns the unique identifier that corresponds to this pass.
- const PassID *getPassID() const { return passIDAndKind.getPointer(); }
-
- /// Returns the pass info for the specified pass class or null if unknown.
- static const PassInfo *lookupPassInfo(const PassID *passID);
- template <typename PassT> static const PassInfo *lookupPassInfo() {
- return lookupPassInfo(PassID::getID<PassT>());
- }
-
- /// Returns the pass info for this pass.
- const PassInfo *lookupPassInfo() const { return lookupPassInfo(getPassID()); }
-
- /// Return the kind of this pass.
- Kind getKind() const { return passIDAndKind.getInt(); }
-
- /// Returns the derived pass name.
- virtual StringRef getName() = 0;
-
-protected:
- Pass(const PassID *passID, Kind kind) : passIDAndKind(passID, kind) {}
-
-private:
- /// Out of line virtual method to ensure vtables and metadata are emitted to a
- /// single .o file.
- virtual void anchor();
-
- /// Represents a unique identifier for the pass and its kind.
- llvm::PointerIntPair<const PassID *, 1, Kind> passIDAndKind;
-};
-
namespace detail {
class FunctionPassExecutor;
class ModulePassExecutor;
/// The state for a single execution of a pass. This provides a unified
/// interface for accessing and initializing necessary state for pass execution.
-template <typename IRUnitT> struct PassExecutionState {
- PassExecutionState(IRUnitT ir, AnalysisManager analysisManager)
+struct PassExecutionState {
+ PassExecutionState(Operation *ir, AnalysisManager analysisManager)
: irAndPassFailed(ir, false), analysisManager(analysisManager) {}
- /// The current IR unit being transformed and a bool for if the pass signaled
- /// a failure.
- llvm::PointerIntPair<IRUnitT, 1, bool> irAndPassFailed;
+ /// The current operation being transformed and a bool for if the pass
+ /// signaled a failure.
+ llvm::PointerIntPair<Operation *, 1, bool> irAndPassFailed;
- /// The analysis manager for the IR unit.
+ /// The analysis manager for the operation.
AnalysisManager analysisManager;
/// The set of preserved analyses for the current execution.
};
} // namespace detail
-/// Pass to transform a specific function within a module. Derived passes should
-/// not inherit from this class directly, and instead should use the CRTP
-/// FunctionPass class.
-class FunctionPassBase : public Pass {
- using PassStateT = detail::PassExecutionState<FuncOp>;
-
+/// The abstract base pass class. This class contains information describing the
+/// derived pass object, e.g its kind and abstract PassInfo.
+class Pass {
public:
- static bool classof(const Pass *pass) {
- return pass->getKind() == Kind::FunctionPass;
- }
+ virtual ~Pass() = default;
-protected:
- explicit FunctionPassBase(const PassID *id) : Pass(id, Kind::FunctionPass) {}
+ /// Returns the unique identifier that corresponds to this pass.
+ const PassID *getPassID() const { return passID; }
- /// The polymorphic API that runs the pass over the currently held function.
- virtual void runOnFunction() = 0;
+ /// Returns the pass info for the specified pass class or null if unknown.
+ static const PassInfo *lookupPassInfo(const PassID *passID);
+ template <typename PassT> static const PassInfo *lookupPassInfo() {
+ return lookupPassInfo(PassID::getID<PassT>());
+ }
- /// A clone method to create a copy of this pass.
- virtual std::unique_ptr<FunctionPassBase> clone() const = 0;
+ /// Returns the pass info for this pass.
+ const PassInfo *lookupPassInfo() const { return lookupPassInfo(getPassID()); }
- /// Return the current function being transformed.
- FuncOp getFunction() { return getPassState().irAndPassFailed.getPointer(); }
+ /// Returns the derived pass name.
+ virtual StringRef getName() = 0;
- /// Return the MLIR context for the current function being transformed.
- MLIRContext &getContext() { return *getFunction().getContext(); }
+ /// Returns the name of the operation that this pass operates on.
+ StringRef getOpName() const { return opName; }
+
+protected:
+ Pass(const PassID *passID, StringRef opName)
+ : passID(passID), opName(opName) {}
/// Returns the current pass state.
- PassStateT &getPassState() {
+ detail::PassExecutionState &getPassState() {
assert(passState && "pass state was never initialized");
return *passState;
}
- /// Returns the current analysis manager.
- AnalysisManager getAnalysisManager() {
- return getPassState().analysisManager;
- }
-
-private:
- /// Forwarding function to execute this pass.
- LLVM_NODISCARD
- LogicalResult run(FuncOp fn, AnalysisManager am);
-
- /// The current execution state for the pass.
- llvm::Optional<PassStateT> passState;
-
- /// Allow access to 'run'.
- friend detail::FunctionPassExecutor;
-};
-
-/// Pass to transform a module. Derived passes should not inherit from this
-/// class directly, and instead should use the CRTP ModulePass class.
-class ModulePassBase : public Pass {
- using PassStateT = detail::PassExecutionState<ModuleOp>;
-
-public:
- static bool classof(const Pass *pass) {
- return pass->getKind() == Kind::ModulePass;
- }
-
-protected:
- explicit ModulePassBase(const PassID *id) : Pass(id, Kind::ModulePass) {}
-
- /// The polymorphic API that runs the pass over the currently held module.
- virtual void runOnModule() = 0;
+ /// Return the MLIR context for the current function being transformed.
+ MLIRContext &getContext() { return *getOperation()->getContext(); }
- /// Return the current module being transformed.
- ModuleOp getModule() { return getPassState().irAndPassFailed.getPointer(); }
+ /// The polymorphic API that runs the pass over the currently held operation.
+ virtual void runOnOperation() = 0;
- /// Return the MLIR context for the current module being transformed.
- MLIRContext &getContext() { return *getModule().getContext(); }
+ /// A clone method to create a copy of this pass.
+ virtual std::unique_ptr<Pass> clone() const = 0;
- /// Returns the current pass state.
- PassStateT &getPassState() {
- assert(passState && "pass state was never initialized");
- return *passState;
+ /// Return the current operation being transformed.
+ Operation *getOperation() {
+ return getPassState().irAndPassFailed.getPointer();
}
/// Returns the current analysis manager.
}
private:
- /// Forwarding function to execute this pass.
+ /// Forwarding function to execute this pass on the given operation.
LLVM_NODISCARD
- LogicalResult run(ModuleOp module, AnalysisManager am);
+ LogicalResult run(Operation *op, AnalysisManager am);
+
+ /// Out of line virtual method to ensure vtables and metadata are emitted to a
+ /// single .o file.
+ virtual void anchor();
+
+ /// Represents a unique identifier for the pass.
+ const PassID *passID;
+
+ /// The name of the operation that this pass operates on.
+ StringRef opName;
/// The current execution state for the pass.
- llvm::Optional<PassStateT> passState;
+ llvm::Optional<detail::PassExecutionState> passState;
- /// Allow access to 'run'.
+ /// Allow access to 'clone' and 'run'.
+ friend detail::FunctionPassExecutor;
friend detail::ModulePassExecutor;
};
namespace detail {
/// The opaque CRTP model of a pass. This class provides utilities for derived
/// pass execution and handles all of the necessary polymorphic API.
-template <typename IRUnitT, typename PassT, typename BasePassT>
+template <typename PassT, typename BasePassT>
class PassModel : public BasePassT {
public:
/// Support isa/dyn_cast functionality for the derived pass class.
}
protected:
- PassModel() : BasePassT(PassID::getID<PassT>()) {}
+ PassModel(StringRef opName) : BasePassT(PassID::getID<PassT>(), opName) {}
/// Signal that some invariant was broken when running. The IR is allowed to
/// be in an invalid state.
name.consume_front("(anonymous namespace)::");
return name;
}
+
+ /// A clone method to create a copy of this pass.
+ std::unique_ptr<Pass> clone() const override {
+ return std::make_unique<PassT>(*static_cast<const PassT *>(this));
+ }
+
+ /// Returns the analysis for the parent operation if it exists.
+ template <typename AnalysisT>
+ llvm::Optional<std::reference_wrapper<AnalysisT>>
+ getCachedParentAnalysis(Operation *parent) {
+ return this->getAnalysisManager()
+ .template getCachedParentAnalysis<AnalysisT>(parent);
+ }
+ template <typename AnalysisT>
+ llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis() {
+ return this->getAnalysisManager()
+ .template getCachedParentAnalysis<AnalysisT>(
+ this->getOperation()->getParentOp());
+ }
+
+ /// Returns the analysis for the given child operation if it exists.
+ template <typename AnalysisT>
+ llvm::Optional<std::reference_wrapper<AnalysisT>>
+ getCachedChildAnalysis(Operation *child) {
+ return this->getAnalysisManager()
+ .template getCachedChildAnalysis<AnalysisT>(child);
+ }
+
+ /// Returns the analysis for the given child operation, or creates it if it
+ /// doesn't exist.
+ template <typename AnalysisT> AnalysisT &getChildAnalysis(Operation *child) {
+ return this->getAnalysisManager().template getChildAnalysis<AnalysisT>(
+ child);
+ }
};
} // end namespace detail
+/// Utility base class for OpPass below to denote an opaque pass operating on a
+/// specific operation type.
+template <typename OpT> class OpPassBase : public Pass {
+public:
+ using Pass::Pass;
+
+ /// Support isa/dyn_cast functionality.
+ static bool classof(const Pass *pass) {
+ return pass->getOpName() == OpT::getOperationName();
+ }
+};
+
+/// Pass to transform an operation of a specific type.
+///
+/// Operation passes must not:
+/// - read or modify any other operations within the parent region, as
+/// other threads may be manipulating them concurrently.
+/// - modify any state within the parent operation, this includes adding
+/// additional operations.
+///
+/// Derived function passes are expected to provide the following:
+/// - A 'void runOnOperation()' method.
+template <typename OpT, typename PassT>
+class OpPass : public detail::PassModel<PassT, OpPassBase<OpT>> {
+protected:
+ OpPass()
+ : detail::PassModel<PassT, OpPassBase<OpT>>(OpT::getOperationName()) {}
+
+ /// Return the current operation being transformed.
+ OpT getOperation() { return cast<OpT>(Pass::getOperation()); }
+};
+
/// A model for providing function pass specific utilities.
///
/// Function passes must not:
///
/// Derived function passes are expected to provide the following:
/// - A 'void runOnFunction()' method.
-template <typename T>
-struct FunctionPass : public detail::PassModel<FuncOp, T, FunctionPassBase> {
- /// Returns the analysis for the parent module if it exists.
- template <typename AnalysisT>
- llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedModuleAnalysis() {
- return this->getAnalysisManager()
- .template getCachedParentAnalysis<AnalysisT>(
- this->getFunction().getParentOp());
- }
+template <typename T> struct FunctionPass : public OpPass<FuncOp, T> {
+ /// The polymorphic API that runs the pass over the currently held function.
+ virtual void runOnFunction() = 0;
- /// A clone method to create a copy of this pass.
- std::unique_ptr<FunctionPassBase> clone() const override {
- return std::make_unique<T>(*static_cast<const T *>(this));
+ /// The polymorphic API that runs the pass over the currently held operation.
+ void runOnOperation() final {
+ if (!getFunction().isExternal())
+ runOnFunction();
}
+
+ /// Return the current module being transformed.
+ FuncOp getFunction() { return this->getOperation(); }
};
/// A model for providing module pass specific utilities.
///
/// Derived module passes are expected to provide the following:
/// - A 'void runOnModule()' method.
-template <typename T>
-struct ModulePass : public detail::PassModel<ModuleOp, T, ModulePassBase> {
- /// Returns the analysis for a child function.
- template <typename AnalysisT> AnalysisT &getFunctionAnalysis(FuncOp f) {
- return this->getAnalysisManager().template getChildAnalysis<AnalysisT>(f);
- }
+template <typename T> struct ModulePass : public OpPass<ModuleOp, T> {
+ /// The polymorphic API that runs the pass over the currently held module.
+ virtual void runOnModule() = 0;
- /// Returns an existing analysis for a child function if it exists.
- template <typename AnalysisT>
- llvm::Optional<std::reference_wrapper<AnalysisT>>
- getCachedFunctionAnalysis(FuncOp f) {
- return this->getAnalysisManager()
- .template getCachedChildAnalysis<AnalysisT>(f);
- }
+ /// The polymorphic API that runs the pass over the currently held operation.
+ void runOnOperation() final { runOnModule(); }
+
+ /// Return the current module being transformed.
+ ModuleOp getModule() { return this->getOperation(); }
};
+
+/// Using directives defining legacy base classes.
+// TODO(riverriddle) These should be removed in favor of OpPassBase<T>.
+using FunctionPassBase = OpPassBase<FuncOp>;
+using ModulePassBase = OpPassBase<ModuleOp>;
} // end namespace mlir
#endif // MLIR_PASS_PASS_H
void Pass::anchor() {}
/// Forwarding function to execute this pass.
-LogicalResult FunctionPassBase::run(FuncOp fn, AnalysisManager am) {
- // Initialize the pass state.
- passState.emplace(fn, am);
+LogicalResult Pass::run(Operation *op, AnalysisManager am) {
+ passState.emplace(op, am);
// Instrument before the pass has run.
auto pi = am.getPassInstrumentor();
if (pi)
- pi->runBeforePass(this, fn);
+ pi->runBeforePass(this, op);
- // Invoke the virtual runOnFunction function.
- runOnFunction();
+ // Invoke the virtual runOnOperation method.
+ runOnOperation();
// Invalidate any non preserved analyses.
am.invalidate(passState->preservedAnalyses);
bool passFailed = passState->irAndPassFailed.getInt();
if (pi) {
if (passFailed)
- pi->runAfterPassFailed(this, fn);
+ pi->runAfterPassFailed(this, op);
else
- pi->runAfterPass(this, fn);
- }
-
- // Return if the pass signaled a failure.
- return failure(passFailed);
-}
-
-/// Forwarding function to execute this pass.
-LogicalResult ModulePassBase::run(ModuleOp module, AnalysisManager am) {
- // Initialize the pass state.
- passState.emplace(module, am);
-
- // Instrument before the pass has run.
- auto pi = am.getPassInstrumentor();
- if (pi)
- pi->runBeforePass(this, module);
-
- // Invoke the virtual runOnModule function.
- runOnModule();
-
- // Invalidate any non preserved analyses.
- am.invalidate(passState->preservedAnalyses);
-
- // Instrument after the pass has run.
- bool passFailed = passState->irAndPassFailed.getInt();
- if (pi) {
- if (passFailed)
- pi->runAfterPassFailed(this, module);
- else
- pi->runAfterPass(this, module);
+ pi->runAfterPass(this, op);
}
// Return if the pass signaled a failure.
FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs)
: PassExecutor(Kind::FunctionExecutor) {
for (auto &pass : rhs.passes)
- addPass(pass->clone());
+ addPass(cast<FunctionPassBase>(pass->clone()));
}
/// Run all of the passes in this manager over the current function.
/// Add an opaque pass pointer to the current manager. This takes ownership
/// over the provided pass pointer.
void PassManager::addPass(std::unique_ptr<Pass> pass) {
- switch (pass->getKind()) {
- case Pass::Kind::FunctionPass:
- addPass(cast<FunctionPassBase>(std::move(pass)));
- break;
- case Pass::Kind::ModulePass:
- addPass(cast<ModulePassBase>(std::move(pass)));
- break;
- }
+ if (isa<FunctionPassBase>(pass.get()))
+ return addPass(cast<FunctionPassBase>(std::move(pass)));
+ addPass(cast<ModulePassBase>(std::move(pass)));
}
/// Add a module pass to the current manager. This takes ownership over the