Use PassGate from LLVMContext if any otherwise global one
authorEvgeniy Brevnov <ybrevnov@azul.com>
Tue, 1 Nov 2022 06:42:07 +0000 (13:42 +0700)
committerEvgeniy Brevnov <ybrevnov@azul.com>
Fri, 25 Nov 2022 08:13:04 +0000 (15:13 +0700)
Differential Revision: https://reviews.llvm.org/D137149

17 files changed:
clang/lib/CodeGen/BackendUtil.cpp
flang/lib/Frontend/FrontendActions.cpp
llvm/include/llvm/IR/OptBisect.h
llvm/include/llvm/Passes/StandardInstrumentations.h
llvm/lib/Analysis/CallGraphSCCPass.cpp
llvm/lib/Analysis/LoopPass.cpp
llvm/lib/Analysis/RegionPass.cpp
llvm/lib/IR/LLVMContextImpl.cpp
llvm/lib/IR/OptBisect.cpp
llvm/lib/IR/Pass.cpp
llvm/lib/LTO/LTOBackend.cpp
llvm/lib/LTO/ThinLTOCodeGenerator.cpp
llvm/lib/Passes/PassBuilderBindings.cpp
llvm/lib/Passes/StandardInstrumentations.cpp
llvm/tools/opt/NewPMDriver.cpp
llvm/unittests/IR/LegacyPassManagerTest.cpp
llvm/unittests/IR/PassManagerTest.cpp

index 99a3007..2b24196 100644 (file)
@@ -851,9 +851,10 @@ void EmitAssemblyHelper::RunOptimizationPipeline(
   PrintPassOptions PrintPassOpts;
   PrintPassOpts.Indent = DebugPassStructure;
   PrintPassOpts.SkipAnalyses = DebugPassStructure;
-  StandardInstrumentations SI(CodeGenOpts.DebugPassManager ||
-                                  DebugPassStructure,
-                              /*VerifyEach*/ false, PrintPassOpts);
+  StandardInstrumentations SI(
+      TheModule->getContext(),
+      (CodeGenOpts.DebugPassManager || DebugPassStructure),
+      /*VerifyEach*/ false, PrintPassOpts);
   SI.registerCallbacks(PIC, &FAM);
   PassBuilder PB(TM.get(), PTO, PGOOpt, &PIC);
 
index 9deef99..990cb76 100644 (file)
@@ -694,7 +694,8 @@ void CodeGenAction::runOptimizationPipeline(llvm::raw_pwrite_stream &os) {
   llvm::PassInstrumentationCallbacks pic;
   llvm::PipelineTuningOptions pto;
   llvm::Optional<llvm::PGOOptions> pgoOpt;
-  llvm::StandardInstrumentations si(opts.DebugPassManager);
+  llvm::StandardInstrumentations si(
+      llvmModule->getContext(), opts.DebugPassManager);
   si.registerCallbacks(pic, &fam);
   llvm::PassBuilder pb(tm.get(), pto, pgoOpt, &pic);
 
index 14488bb..6ebb9be 100644 (file)
@@ -29,7 +29,8 @@ public:
 
   /// IRDescription is a textual description of the IR unit the pass is running
   /// over.
-  virtual bool shouldRunPass(const Pass *P, StringRef IRDescription) {
+  virtual bool shouldRunPass(const StringRef PassName,
+                             StringRef IRDescription) {
     return true;
   }
 
@@ -55,7 +56,8 @@ public:
   /// Checks the bisect limit to determine if the specified pass should run.
   ///
   /// This forwards to checkPass().
-  bool shouldRunPass(const Pass *P, StringRef IRDescription) override;
+  bool shouldRunPass(const StringRef PassName,
+                     StringRef IRDescription) override;
 
   /// isEnabled() should return true before calling shouldRunPass().
   bool isEnabled() const override { return BisectLimit != Disabled; }
@@ -89,7 +91,7 @@ private:
 
 /// Singleton instance of the OptBisect class, so multiple pass managers don't
 /// need to coordinate their uses of OptBisect.
-OptBisect &getOptBisector();
+OptPassGate &getGlobalPassGate();
 
 } // end namespace llvm
 
index 19d3cbc..760decc 100644 (file)
@@ -74,11 +74,12 @@ private:
   bool shouldRun(StringRef PassID, Any IR);
 };
 
-class OptBisectInstrumentation {
+class OptPassGateInstrumentation {
+  LLVMContext &Context;
   bool HasWrittenIR = false;
-
 public:
-  OptBisectInstrumentation() = default;
+  OptPassGateInstrumentation(LLVMContext &Context) : Context(Context) {}
+  bool shouldRun(StringRef PassName, Any IR);
   void registerCallbacks(PassInstrumentationCallbacks &PIC);
 };
 
@@ -528,7 +529,7 @@ class StandardInstrumentations {
   TimePassesHandler TimePasses;
   TimeProfilingPassesHandler TimeProfilingPasses;
   OptNoneInstrumentation OptNone;
-  OptBisectInstrumentation OptBisect;
+  OptPassGateInstrumentation OptPassGate;
   PreservedCFGCheckerInstrumentation PreservedCFGChecker;
   IRChangedPrinter PrintChangedIR;
   PseudoProbeVerifier PseudoProbeVerification;
@@ -540,7 +541,8 @@ class StandardInstrumentations {
   bool VerifyEach;
 
 public:
-  StandardInstrumentations(bool DebugLogging, bool VerifyEach = false,
+  StandardInstrumentations(LLVMContext &Context, bool DebugLogging,
+                           bool VerifyEach = false,
                            PrintPassOptions PrintPassOpts = PrintPassOptions());
 
   // Register all the standard instrumentation callbacks. If \p FAM is nullptr
index 8438f33..669464b 100644 (file)
@@ -751,7 +751,8 @@ static std::string getDescription(const CallGraphSCC &SCC) {
 bool CallGraphSCCPass::skipSCC(CallGraphSCC &SCC) const {
   OptPassGate &Gate =
       SCC.getCallGraph().getModule().getContext().getOptPassGate();
-  return Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(SCC));
+  return Gate.isEnabled() &&
+         !Gate.shouldRunPass(this->getPassName(), getDescription(SCC));
 }
 
 char DummyCGSCCPass::ID = 0;
index 5d824ae..294dfd9 100644 (file)
@@ -373,7 +373,8 @@ bool LoopPass::skipLoop(const Loop *L) const {
     return false;
   // Check the opt bisect limit.
   OptPassGate &Gate = F->getContext().getOptPassGate();
-  if (Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(*L)))
+  if (Gate.isEnabled() &&
+      !Gate.shouldRunPass(this->getPassName(), getDescription(*L)))
     return true;
   // Check for the OptimizeNone attribute.
   if (F->hasOptNone()) {
index ddef3be..9ea7d71 100644 (file)
@@ -283,7 +283,8 @@ static std::string getDescription(const Region &R) {
 bool RegionPass::skipRegion(Region &R) const {
   Function &F = *R.getEntry()->getParent();
   OptPassGate &Gate = F.getContext().getOptPassGate();
-  if (Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(R)))
+  if (Gate.isEnabled() &&
+      !Gate.shouldRunPass(this->getPassName(), getDescription(R)))
     return true;
 
   if (F.hasOptNone()) {
index d7aaf00..e0e25f2 100644 (file)
@@ -240,7 +240,7 @@ void LLVMContextImpl::getSyncScopeNames(
 /// singleton OptBisect if not explicitly set.
 OptPassGate &LLVMContextImpl::getOptPassGate() const {
   if (!OPG)
-    OPG = &getOptBisector();
+    OPG = &getGlobalPassGate();
   return *OPG;
 }
 
index c9054db..893a5e5 100644 (file)
 
 using namespace llvm;
 
+static OptBisect &getOptBisector() {
+  static OptBisect OptBisector;
+  return OptBisector;
+}
+
 static cl::opt<int> OptBisectLimit("opt-bisect-limit", cl::Hidden,
                                    cl::init(OptBisect::Disabled), cl::Optional,
                                    cl::cb<void, int>([](int Limit) {
-                                     llvm::getOptBisector().setLimit(Limit);
+                                     getOptBisector().setLimit(Limit);
                                    }),
                                    cl::desc("Maximum optimization to perform"));
 
@@ -34,25 +39,16 @@ static void printPassMessage(const StringRef &Name, int PassNum,
          << "(" << PassNum << ") " << Name << " on " << TargetDesc << "\n";
 }
 
-bool OptBisect::shouldRunPass(const Pass *P, StringRef IRDescription) {
-  assert(isEnabled());
-
-  return checkPass(P->getPassName(), IRDescription);
-}
-
-bool OptBisect::checkPass(const StringRef PassName,
-                          const StringRef TargetDesc) {
+bool OptBisect::shouldRunPass(const StringRef PassName,
+                              StringRef IRDescription) {
   assert(isEnabled());
 
   int CurBisectNum = ++LastBisectNum;
   bool ShouldRun = (BisectLimit == -1 || CurBisectNum <= BisectLimit);
-  printPassMessage(PassName, CurBisectNum, TargetDesc, ShouldRun);
+  printPassMessage(PassName, CurBisectNum, IRDescription, ShouldRun);
   return ShouldRun;
 }
 
 const int OptBisect::Disabled;
 
-OptBisect &llvm::getOptBisector() {
-  static OptBisect OptBisector;
-  return OptBisector;
-}
+OptPassGate &llvm::getGlobalPassGate() { return getOptBisector(); }
index fe0bfd8..716d9d5 100644 (file)
@@ -62,7 +62,8 @@ static std::string getDescription(const Module &M) {
 
 bool ModulePass::skipModule(Module &M) const {
   OptPassGate &Gate = M.getContext().getOptPassGate();
-  return Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(M));
+  return Gate.isEnabled() &&
+         !Gate.shouldRunPass(this->getPassName(), getDescription(M));
 }
 
 bool Pass::mustPreserveAnalysisID(char &AID) const {
@@ -172,7 +173,8 @@ static std::string getDescription(const Function &F) {
 
 bool FunctionPass::skipFunction(const Function &F) const {
   OptPassGate &Gate = F.getContext().getOptPassGate();
-  if (Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(F)))
+  if (Gate.isEnabled() &&
+      !Gate.shouldRunPass(this->getPassName(), getDescription(F)))
     return true;
 
   if (F.hasOptNone()) {
index 5a8f60d..9298e45 100644 (file)
@@ -256,7 +256,7 @@ static void runNewPMPasses(const Config &Conf, Module &Mod, TargetMachine *TM,
   ModuleAnalysisManager MAM;
 
   PassInstrumentationCallbacks PIC;
-  StandardInstrumentations SI(Conf.DebugPassManager);
+  StandardInstrumentations SI(Mod.getContext(), Conf.DebugPassManager);
   SI.registerCallbacks(PIC, &FAM);
   PassBuilder PB(TM, Conf.PTO, PGOOpt, &PIC);
 
index 935d8ec..190ef3f 100644 (file)
@@ -244,7 +244,7 @@ static void optimizeModule(Module &TheModule, TargetMachine &TM,
   ModuleAnalysisManager MAM;
 
   PassInstrumentationCallbacks PIC;
-  StandardInstrumentations SI(DebugPassManager);
+  StandardInstrumentations SI(TheModule.getContext(), DebugPassManager);
   SI.registerCallbacks(PIC, &FAM);
   PipelineTuningOptions PTO;
   PTO.LoopVectorization = true;
index bad1fab..54108a6 100644 (file)
@@ -65,7 +65,7 @@ LLVMErrorRef LLVMRunPasses(LLVMModuleRef M, const char *Passes,
   PB.registerModuleAnalyses(MAM);
   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
 
-  StandardInstrumentations SI(Debug, VerifyEach);
+  StandardInstrumentations SI(Mod->getContext(), Debug, VerifyEach);
   SI.registerCallbacks(PIC, &FAM);
   ModulePassManager MPM;
   if (VerifyEach) {
index 311d773..7dd0765 100644 (file)
@@ -767,27 +767,35 @@ bool OptNoneInstrumentation::shouldRun(StringRef PassID, Any IR) {
   return ShouldRun;
 }
 
-void OptBisectInstrumentation::registerCallbacks(
+bool OptPassGateInstrumentation::shouldRun(StringRef PassName, Any IR) {
+  if (isIgnored(PassName))
+    return true;
+
+  bool ShouldRun =
+      Context.getOptPassGate().shouldRunPass(PassName, getIRName(IR));
+  if (!ShouldRun && !this->HasWrittenIR && !OptBisectPrintIRPath.empty()) {
+    // FIXME: print IR if limit is higher than number of opt-bisect
+    // invocations
+    this->HasWrittenIR = true;
+    const Module *M = unwrapModule(IR, /*Force=*/true);
+    assert((M && &M->getContext() == &Context) && "Missing/Mismatching Module");
+    std::error_code EC;
+    raw_fd_ostream OS(OptBisectPrintIRPath, EC);
+    if (EC)
+      report_fatal_error(errorCodeToError(EC));
+    M->print(OS, nullptr);
+  }
+  return ShouldRun;
+}
+
+void OptPassGateInstrumentation::registerCallbacks(
     PassInstrumentationCallbacks &PIC) {
-  if (!getOptBisector().isEnabled())
+  OptPassGate &PassGate = Context.getOptPassGate();
+  if (!PassGate.isEnabled())
     return;
-  PIC.registerShouldRunOptionalPassCallback([this](StringRef PassID, Any IR) {
-    if (isIgnored(PassID))
-      return true;
-    bool ShouldRun = getOptBisector().checkPass(PassID, getIRName(IR));
-    if (!ShouldRun && !this->HasWrittenIR && !OptBisectPrintIRPath.empty()) {
-      // FIXME: print IR if limit is higher than number of opt-bisect
-      // invocations
-      this->HasWrittenIR = true;
-      const Module *M = unwrapModule(IR, /*Force=*/true);
-      assert(M && "expected Module");
-      std::error_code EC;
-      raw_fd_ostream OS(OptBisectPrintIRPath, EC);
-      if (EC)
-        report_fatal_error(errorCodeToError(EC));
-      M->print(OS, nullptr);
-    }
-    return ShouldRun;
+
+  PIC.registerShouldRunOptionalPassCallback([this](StringRef PassName, Any IR) {
+    return this->shouldRun(PassName, IR);
   });
 }
 
@@ -2037,8 +2045,11 @@ void DotCfgChangeReporter::registerCallbacks(
 }
 
 StandardInstrumentations::StandardInstrumentations(
-    bool DebugLogging, bool VerifyEach, PrintPassOptions PrintPassOpts)
-    : PrintPass(DebugLogging, PrintPassOpts), OptNone(DebugLogging),
+    LLVMContext &Context, bool DebugLogging, bool VerifyEach,
+    PrintPassOptions PrintPassOpts)
+    : PrintPass(DebugLogging, PrintPassOpts),
+      OptNone(DebugLogging),
+      OptPassGate(Context),
       PrintChangedIR(PrintChanged == ChangePrinter::Verbose),
       PrintChangedDiff(PrintChanged == ChangePrinter::DiffVerbose ||
                            PrintChanged == ChangePrinter::ColourDiffVerbose,
@@ -2099,7 +2110,7 @@ void StandardInstrumentations::registerCallbacks(
   PrintPass.registerCallbacks(PIC);
   TimePasses.registerCallbacks(PIC);
   OptNone.registerCallbacks(PIC);
-  OptBisect.registerCallbacks(PIC);
+  OptPassGate.registerCallbacks(PIC);
   if (FAM)
     PreservedCFGChecker.registerCallbacks(PIC, *FAM);
   PrintChangedIR.registerCallbacks(PIC);
index 3bf6333..884ffa1 100644 (file)
@@ -354,8 +354,8 @@ bool llvm::runPassPipeline(StringRef Arg0, Module &M, TargetMachine *TM,
   PrintPassOptions PrintPassOpts;
   PrintPassOpts.Verbose = DebugPM == DebugLogging::Verbose;
   PrintPassOpts.SkipAnalyses = DebugPM == DebugLogging::Quiet;
-  StandardInstrumentations SI(DebugPM != DebugLogging::None, VerifyEachPass,
-                              PrintPassOpts);
+  StandardInstrumentations SI(M.getContext(), DebugPM != DebugLogging::None,
+                              VerifyEachPass, PrintPassOpts);
   SI.registerCallbacks(PIC, &FAM);
   DebugifyEachInstrumentation Debugify;
   DebugifyStatsMap DIStatsMap;
index f674427..0c8a213 100644 (file)
@@ -359,10 +359,8 @@ namespace llvm {
     struct CustomOptPassGate : public OptPassGate {
       bool Skip;
       CustomOptPassGate(bool Skip) : Skip(Skip) { }
-      bool shouldRunPass(const Pass *P, StringRef IRDescription) override {
-        if (P->getPassKind() == PT_Module)
-          return !Skip;
-        return OptPassGate::shouldRunPass(P, IRDescription);
+      bool shouldRunPass(const StringRef PassName, StringRef IRDescription) override {
+        return !Skip;
       }
       bool isEnabled() const override { return true; }
     };
index 98f516f..bae5d46 100644 (file)
@@ -826,7 +826,7 @@ TEST_F(PassManagerTest, FunctionPassCFGChecker) {
   FunctionAnalysisManager FAM;
   FunctionPassManager FPM;
   PassInstrumentationCallbacks PIC;
-  StandardInstrumentations SI(/*DebugLogging*/ true);
+  StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true);
   SI.registerCallbacks(PIC, &FAM);
   FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return DominatorTreeAnalysis(); });
@@ -872,7 +872,7 @@ TEST_F(PassManagerTest, FunctionPassCFGCheckerInvalidateAnalysis) {
   FunctionAnalysisManager FAM;
   FunctionPassManager FPM;
   PassInstrumentationCallbacks PIC;
-  StandardInstrumentations SI(/*DebugLogging*/ true);
+  StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true);
   SI.registerCallbacks(PIC, &FAM);
   FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return DominatorTreeAnalysis(); });
@@ -937,7 +937,7 @@ TEST_F(PassManagerTest, FunctionPassCFGCheckerWrapped) {
   FunctionAnalysisManager FAM;
   FunctionPassManager FPM;
   PassInstrumentationCallbacks PIC;
-  StandardInstrumentations SI(/*DebugLogging*/ true);
+  StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true);
   SI.registerCallbacks(PIC, &FAM);
   FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return DominatorTreeAnalysis(); });