[mlir][Inliner] Refactor the inliner to use nested pass pipelines instead of just...
authorRiver Riddle <riddleriver@gmail.com>
Tue, 15 Dec 2020 02:07:45 +0000 (18:07 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 15 Dec 2020 02:09:47 +0000 (18:09 -0800)
Now that passes have support for running nested pipelines, the inliner can now allow for users to provide proper nested pipelines to use for optimization during inlining. This revision also changes the behavior of optimization during inlining to optimize before attempting to inline, which should lead to a more accurate cost model and prevents the need for users to schedule additional duplicate cleanup passes before/after the inliner that would already be run during inlining.

Differential Revision: https://reviews.llvm.org/D91211

16 files changed:
llvm/include/llvm/ADT/Sequence.h
mlir/include/mlir/Pass/AnalysisManager.h
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassManager.h
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h
mlir/lib/Pass/PassRegistry.cpp
mlir/lib/Pass/PassTiming.cpp
mlir/lib/Transforms/Inliner.cpp
mlir/test/Dialect/Affine/inlining.mlir
mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
mlir/test/Pass/dynamic-pipeline-nested.mlir
mlir/test/Transforms/inlining.mlir
mlir/test/lib/Transforms/TestDynamicPipeline.cpp

index 8c505f2..8a695d7 100644 (file)
@@ -42,6 +42,10 @@ public:
   value_sequence_iterator(const value_sequence_iterator &) = default;
   value_sequence_iterator(value_sequence_iterator &&Arg)
       : Value(std::move(Arg.Value)) {}
+  value_sequence_iterator &operator=(const value_sequence_iterator &Arg) {
+    Value = Arg.Value;
+    return *this;
+  }
 
   template <typename U, typename Enabler = decltype(ValueT(std::declval<U>()))>
   value_sequence_iterator(U &&Value) : Value(std::forward<U>(Value)) {}
index ec6b769..5da0c95 100644 (file)
@@ -98,7 +98,7 @@ struct AnalysisConcept {
 /// A derived analysis model used to hold a specific analysis object.
 template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
   template <typename... Args>
-  explicit AnalysisModel(Args &&... args)
+  explicit AnalysisModel(Args &&...args)
       : analysis(std::forward<Args>(args)...) {}
 
   /// A hook used to query analyses for invalidation.
@@ -198,7 +198,10 @@ private:
 /// An analysis map that contains a map for the current operation, and a set of
 /// maps for any child operations.
 struct NestedAnalysisMap {
-  NestedAnalysisMap(Operation *op) : analyses(op) {}
+  NestedAnalysisMap(Operation *op, PassInstrumentor *instrumentor)
+      : analyses(op), parentOrInstrumentor(instrumentor) {}
+  NestedAnalysisMap(Operation *op, NestedAnalysisMap *parent)
+      : analyses(op), parentOrInstrumentor(parent) {}
 
   /// Get the operation for this analysis map.
   Operation *getOperation() const { return analyses.getOperation(); }
@@ -206,11 +209,34 @@ struct NestedAnalysisMap {
   /// Invalidate any non preserved analyses.
   void invalidate(const PreservedAnalyses &pa);
 
+  /// Returns the parent analysis map for this analysis map, or null if this is
+  /// the top-level map.
+  const NestedAnalysisMap *getParent() const {
+    return parentOrInstrumentor.dyn_cast<NestedAnalysisMap *>();
+  }
+
+  /// Returns a pass instrumentation object for the current operation. This
+  /// value may be null.
+  PassInstrumentor *getPassInstrumentor() const {
+    if (auto *parent = getParent())
+      return parent->getPassInstrumentor();
+    return parentOrInstrumentor.get<PassInstrumentor *>();
+  }
+
   /// The cached analyses for nested operations.
   DenseMap<Operation *, std::unique_ptr<NestedAnalysisMap>> childAnalyses;
 
-  /// The analyses for the owning module.
+  /// The analyses for the owning operation.
   detail::AnalysisMap analyses;
+
+  /// This value has three possible states:
+  /// NestedAnalysisMap*: A pointer to the parent analysis map.
+  /// PassInstrumentor*: This analysis map is the top-level map, and this
+  ///                    pointer is the optional pass instrumentor for the
+  ///                    current compilation.
+  /// nullptr: This analysis map is the top-level map, and there is nop pass
+  ///          instrumentor.
+  PointerUnion<NestedAnalysisMap *, PassInstrumentor *> parentOrInstrumentor;
 };
 } // namespace detail
 
@@ -236,11 +262,11 @@ public:
   template <typename AnalysisT>
   Optional<std::reference_wrapper<AnalysisT>>
   getCachedParentAnalysis(Operation *parentOp) const {
-    ParentPointerT curParent = parent;
-    while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>()) {
-      if (parentAM->impl->getOperation() == parentOp)
-        return parentAM->getCachedAnalysis<AnalysisT>();
-      curParent = parentAM->parent;
+    const detail::NestedAnalysisMap *curParent = impl;
+    while (auto *parentAM = curParent->getParent()) {
+      if (parentAM->getOperation() == parentOp)
+        return parentAM->analyses.getCachedAnalysis<AnalysisT>();
+      curParent = parentAM;
     }
     return None;
   }
@@ -286,7 +312,8 @@ public:
     return it->second->analyses.getCachedAnalysis<AnalysisT>();
   }
 
-  /// Get an analysis manager for the given child operation.
+  /// Get an analysis manager for the given operation, which must be a proper
+  /// descendant of the current operation represented by this analysis manager.
   AnalysisManager nest(Operation *op);
 
   /// Invalidate any non preserved analyses,
@@ -300,19 +327,15 @@ public:
 
   /// Returns a pass instrumentation object for the current operation. This
   /// value may be null.
-  PassInstrumentor *getPassInstrumentor() const;
+  PassInstrumentor *getPassInstrumentor() const {
+    return impl->getPassInstrumentor();
+  }
 
 private:
-  AnalysisManager(const AnalysisManager *parent,
-                  detail::NestedAnalysisMap *impl)
-      : parent(parent), impl(impl) {}
-  AnalysisManager(const ModuleAnalysisManager *parent,
-                  detail::NestedAnalysisMap *impl)
-      : parent(parent), impl(impl) {}
+  AnalysisManager(detail::NestedAnalysisMap *impl) : impl(impl) {}
 
-  /// A reference to the parent analysis manager, or the top-level module
-  /// analysis manager.
-  ParentPointerT parent;
+  /// Get an analysis manager for the given immediately nested child operation.
+  AnalysisManager nestImmediate(Operation *op);
 
   /// A reference to the impl analysis map within the parent analysis manager.
   detail::NestedAnalysisMap *impl;
@@ -328,23 +351,16 @@ private:
 class ModuleAnalysisManager {
 public:
   ModuleAnalysisManager(Operation *op, PassInstrumentor *passInstrumentor)
-      : analyses(op), passInstrumentor(passInstrumentor) {}
+      : analyses(oppassInstrumentor) {}
   ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
   ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
 
-  /// Returns a pass instrumentation object for the current module. This value
-  /// may be null.
-  PassInstrumentor *getPassInstrumentor() const { return passInstrumentor; }
-
   /// Returns an analysis manager for the current top-level module.
-  operator AnalysisManager() { return AnalysisManager(this, &analyses); }
+  operator AnalysisManager() { return AnalysisManager(&analyses); }
 
 private:
   /// The analyses for the owning module.
   detail::NestedAnalysisMap analyses;
-
-  /// An optional instrumentation object.
-  PassInstrumentor *passInstrumentor;
 };
 
 } // end namespace mlir
index 87f4f6b..7a95237 100644 (file)
@@ -95,7 +95,7 @@ public:
             typename OptionParser = detail::PassOptions::OptionParser<DataType>>
   struct Option : public detail::PassOptions::Option<DataType, OptionParser> {
     template <typename... Args>
-    Option(Pass &parent, StringRef arg, Args &&... args)
+    Option(Pass &parent, StringRef arg, Args &&...args)
         : detail::PassOptions::Option<DataType, OptionParser>(
               parent.passOptions, arg, std::forward<Args>(args)...) {}
     using detail::PassOptions::Option<DataType, OptionParser>::operator=;
@@ -107,14 +107,17 @@ public:
   struct ListOption
       : public detail::PassOptions::ListOption<DataType, OptionParser> {
     template <typename... Args>
-    ListOption(Pass &parent, StringRef arg, Args &&... args)
+    ListOption(Pass &parent, StringRef arg, Args &&...args)
         : detail::PassOptions::ListOption<DataType, OptionParser>(
               parent.passOptions, arg, std::forward<Args>(args)...) {}
     using detail::PassOptions::ListOption<DataType, OptionParser>::operator=;
   };
 
   /// Attempt to initialize the options of this pass from the given string.
-  LogicalResult initializeOptions(StringRef options);
+  /// Derived classes may override this method to hook into the point at which
+  /// options are initialized, but should generally always invoke this base
+  /// class variant.
+  virtual LogicalResult initializeOptions(StringRef options);
 
   /// Prints out the pass in the textual representation of pipelines. If this is
   /// an adaptor pass, print with the op_name(sub_pass,...) format.
@@ -265,7 +268,6 @@ protected:
   void copyOptionValuesFrom(const Pass *other);
 
 private:
-
   /// Out of line virtual method to ensure vtables and metadata are emitted to a
   /// single .o file.
   virtual void anchor();
index 5e9c9a7..2715ebd 100644 (file)
@@ -48,8 +48,8 @@ struct PassExecutionState;
 class OpPassManager {
 public:
   enum class Nesting { Implicit, Explicit };
-  OpPassManager(Identifier name, Nesting nesting);
-  OpPassManager(StringRef name, Nesting nesting);
+  OpPassManager(Identifier name, Nesting nesting = Nesting::Explicit);
+  OpPassManager(StringRef name, Nesting nesting = Nesting::Explicit);
   OpPassManager(OpPassManager &&rhs);
   OpPassManager(const OpPassManager &rhs);
   ~OpPassManager();
index c092d01..6fe6660 100644 (file)
@@ -107,6 +107,19 @@ std::unique_ptr<Pass> createPrintOpStatsPass();
 /// Creates a pass which inlines calls and callable operations as defined by
 /// the CallGraph.
 std::unique_ptr<Pass> createInlinerPass();
+/// Creates an instance of the inliner pass, and use the provided pass managers
+/// when optimizing callable operations with names matching the key type.
+/// Callable operations with a name not within the provided map will use the
+/// default inliner pipeline during optimization.
+std::unique_ptr<Pass>
+createInlinerPass(llvm::StringMap<OpPassManager> opPipelines);
+/// Creates an instance of the inliner pass, and use the provided pass managers
+/// when optimizing callable operations with names matching the key type.
+/// Callable operations with a name not within the provided map will use the
+/// provided default pipeline builder.
+std::unique_ptr<Pass>
+createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
+                  std::function<void(OpPassManager &)> defaultPipelineBuilder);
 
 /// Creates a pass which performs sparse conditional constant propagation over
 /// nested operations.
index afad7cd..438a468 100644 (file)
@@ -285,9 +285,12 @@ def Inliner : Pass<"inline"> {
   let summary = "Inline function calls";
   let constructor = "mlir::createInlinerPass()";
   let options = [
-    Option<"disableCanonicalization", "disable-simplify", "bool",
-           /*default=*/"false",
-           "Disable running simplifications during inlining">,
+    Option<"defaultPipelineStr", "default-pipeline", "std::string",
+           /*default=*/"", "The default optimizer pipeline used for callables">,
+    ListOption<"opPipelineStrs", "op-pipelines", "std::string",
+               "Callable operation specific optimizer pipelines (in the form "
+               "of `dialect.op(pipeline)`)",
+               "llvm::cl::MiscFlags::CommaSeparated">,
     Option<"maxInliningIterations", "max-iterations", "unsigned",
            /*default=*/"4",
            "Maximum number of iterations when inlining within an SCC">,
index f53a087..d9046be 100644 (file)
@@ -340,22 +340,25 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
 
   // Initialize the pass state with a callback for the pass to dynamically
   // execute a pipeline on the currently visited operation.
-  auto dynamic_pipeline_callback =
-      [op, &am, verifyPasses](OpPassManager &pipeline,
-                              Operation *root) -> LogicalResult {
+  PassInstrumentor *pi = am.getPassInstrumentor();
+  PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
+                                                        pass};
+  auto dynamic_pipeline_callback = [&](OpPassManager &pipeline,
+                                       Operation *root) -> LogicalResult {
     if (!op->isAncestor(root))
       return root->emitOpError()
              << "Trying to schedule a dynamic pipeline on an "
                 "operation that isn't "
                 "nested under the current operation the pass is processing";
+    assert(pipeline.getOpName() == root->getName().getStringRef());
 
-    AnalysisManager nestedAm = am.nest(root);
+    AnalysisManager nestedAm = root == op ? am : am.nest(root);
     return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
-                                          verifyPasses);
+                                          verifyPasses, pi, &parentInfo);
   };
   pass->passState.emplace(op, am, dynamic_pipeline_callback);
+
   // Instrument before the pass has run.
-  PassInstrumentor *pi = am.getPassInstrumentor();
   if (pi)
     pi->runBeforePass(pass, op);
 
@@ -388,7 +391,10 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
 /// Run the given operation and analysis manager on a provided op pass manager.
 LogicalResult OpToOpPassAdaptor::runPipeline(
     iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
-    AnalysisManager am, bool verifyPasses) {
+    AnalysisManager am, bool verifyPasses, PassInstrumentor *instrumentor,
+    const PassInstrumentation::PipelineParentInfo *parentInfo) {
+  assert((!instrumentor || parentInfo) &&
+         "expected parent info if instrumentor is provided");
   auto scope_exit = llvm::make_scope_exit([&] {
     // Clear out any computed operation analyses. These analyses won't be used
     // any more in this pipeline, and this helps reduce the current working set
@@ -398,10 +404,13 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
   });
 
   // Run the pipeline over the provided operation.
+  if (instrumentor)
+    instrumentor->runBeforePipeline(op->getName().getIdentifier(), *parentInfo);
   for (Pass &pass : passes)
     if (failed(run(&pass, op, am, verifyPasses)))
       return failure();
-
+  if (instrumentor)
+    instrumentor->runAfterPipeline(op->getName().getIdentifier(), *parentInfo);
   return success();
 }
 
@@ -491,17 +500,10 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
                                        *op.getContext());
         if (!mgr)
           continue;
-        Identifier opName = mgr->getOpName(*getOperation()->getContext());
 
         // Run the held pipeline over the current operation.
-        if (instrumentor)
-          instrumentor->runBeforePipeline(opName, parentInfo);
-        LogicalResult result =
-            runPipeline(mgr->getPasses(), &op, am.nest(&op), verifyPasses);
-        if (instrumentor)
-          instrumentor->runAfterPipeline(opName, parentInfo);
-
-        if (failed(result))
+        if (failed(runPipeline(mgr->getPasses(), &op, am.nest(&op),
+                               verifyPasses, instrumentor, &parentInfo)))
           return signalPassFailure();
       }
     }
@@ -576,13 +578,9 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
               pms, it.first->getName().getIdentifier(), getContext());
           assert(pm && "expected valid pass manager for operation");
 
-          Identifier opName = pm->getOpName(*getOperation()->getContext());
-          if (instrumentor)
-            instrumentor->runBeforePipeline(opName, parentInfo);
-          auto pipelineResult =
-              runPipeline(pm->getPasses(), it.first, it.second, verifyPasses);
-          if (instrumentor)
-            instrumentor->runAfterPipeline(opName, parentInfo);
+          LogicalResult pipelineResult =
+              runPipeline(pm->getPasses(), it.first, it.second, verifyPasses,
+                          instrumentor, &parentInfo);
 
           // Drop this thread from being tracked by the diagnostic handler.
           // After this task has finished, the thread may be used outside of
@@ -848,22 +846,41 @@ void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
 // AnalysisManager
 //===----------------------------------------------------------------------===//
 
-/// Returns a pass instrumentation object for the current operation.
-PassInstrumentor *AnalysisManager::getPassInstrumentor() const {
-  ParentPointerT curParent = parent;
-  while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>())
-    curParent = parentAM->parent;
-  return curParent.get<const ModuleAnalysisManager *>()->getPassInstrumentor();
+/// Get an analysis manager for the given operation, which must be a proper
+/// descendant of the current operation represented by this analysis manager.
+AnalysisManager AnalysisManager::nest(Operation *op) {
+  Operation *currentOp = impl->getOperation();
+  assert(currentOp->isProperAncestor(op) &&
+         "expected valid descendant operation");
+
+  // Check for the base case where the provided operation is immediately nested.
+  if (currentOp == op->getParentOp())
+    return nestImmediate(op);
+
+  // Otherwise, we need to collect all ancestors up to the current operation.
+  SmallVector<Operation *, 4> opAncestors;
+  do {
+    opAncestors.push_back(op);
+    op = op->getParentOp();
+  } while (op != currentOp);
+
+  AnalysisManager result = *this;
+  for (Operation *op : llvm::reverse(opAncestors))
+    result = result.nestImmediate(op);
+  return result;
 }
 
-/// Get an analysis manager for the given child operation.
-AnalysisManager AnalysisManager::nest(Operation *op) {
+/// Get an analysis manager for the given immediately nested child operation.
+AnalysisManager AnalysisManager::nestImmediate(Operation *op) {
+  assert(impl->getOperation() == op->getParentOp() &&
+         "expected immediate child operation");
+
   auto it = impl->childAnalyses.find(op);
   if (it == impl->childAnalyses.end())
     it = impl->childAnalyses
-             .try_emplace(op, std::make_unique<NestedAnalysisMap>(op))
+             .try_emplace(op, std::make_unique<NestedAnalysisMap>(op, impl))
              .first;
-  return {this, it->second.get()};
+  return {it->second.get()};
 }
 
 /// Invalidate any non preserved analyses.
index d888d57..2533d87 100644 (file)
@@ -60,9 +60,11 @@ private:
 
   /// Run the given operation and analysis manager on a provided op pass
   /// manager.
-  static LogicalResult
-  runPipeline(iterator_range<OpPassManager::pass_iterator> passes,
-              Operation *op, AnalysisManager am, bool verifyPasses);
+  static LogicalResult runPipeline(
+      iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
+      AnalysisManager am, bool verifyPasses,
+      PassInstrumentor *instrumentor = nullptr,
+      const PassInstrumentation::PipelineParentInfo *parentInfo = nullptr);
 
   /// A set of adaptors to run.
   SmallVector<OpPassManager, 1> mgrs;
index 78e40d5..50cbee8 100644 (file)
@@ -291,11 +291,15 @@ private:
 /// given to enable accurate error reporting.
 LogicalResult TextualPipeline::initialize(StringRef text,
                                           raw_ostream &errorStream) {
+  if (text.empty())
+    return success();
+
   // Build a source manager to use for error reporting.
   llvm::SourceMgr pipelineMgr;
-  pipelineMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(
-                                     text, "MLIR Textual PassPipeline Parser"),
-                                 llvm::SMLoc());
+  pipelineMgr.AddNewSourceBuffer(
+      llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
+                                       /*RequiresNullTerminator=*/false),
+      llvm::SMLoc());
   auto errorHandler = [&](const char *rawLoc, Twine msg) {
     pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc),
                              llvm::SourceMgr::DK_Error, msg);
@@ -327,7 +331,7 @@ LogicalResult TextualPipeline::parsePipelineText(StringRef text,
     pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
 
     // If we have a single terminating name, we're done.
-    if (pos == text.npos)
+    if (pos == StringRef::npos)
       break;
 
     text = text.substr(pos);
@@ -338,9 +342,19 @@ LogicalResult TextualPipeline::parsePipelineText(StringRef text,
       text = text.substr(1);
 
       // Skip over everything until the closing '}' and store as options.
-      size_t close = text.find('}');
+      size_t close = StringRef::npos;
+      for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
+        if (text[i] == '{') {
+          ++braceCount;
+          continue;
+        }
+        if (text[i] == '}' && --braceCount == 0) {
+          close = i;
+          break;
+        }
+      }
 
-      // TODO: Handle skipping over quoted sub-strings.
+      // Check to see if a closing options brace was found.
       if (close == StringRef::npos) {
         return errorHandler(
             /*rawLoc=*/text.data() - 1,
index e397875..4998875 100644 (file)
@@ -302,16 +302,13 @@ void PassTiming::startAnalysisTimer(StringRef name, TypeID id) {
 void PassTiming::runAfterPass(Pass *pass, Operation *) {
   Timer *timer = popLastActiveTimer();
 
-  // If this is a pass adaptor, then we need to merge in the timing data for the
-  // pipelines running on other threads.
-  if (isa<OpToOpPassAdaptor>(pass)) {
-    auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass});
-    if (toMerge != pipelinesToMerge.end()) {
-      for (auto &it : toMerge->second)
-        timer->mergeChild(std::move(it));
-      pipelinesToMerge.erase(toMerge);
-    }
-    return;
+  // Check to see if we need to merge in the timing data for the pipelines
+  // running on other threads.
+  auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass});
+  if (toMerge != pipelinesToMerge.end()) {
+    for (auto &it : toMerge->second)
+      timer->mergeChild(std::move(it));
+    pipelinesToMerge.erase(toMerge);
   }
 
   timer->stop();
index 64c7ca8..364af20 100644 (file)
@@ -15,9 +15,8 @@
 
 #include "PassDetail.h"
 #include "mlir/Analysis/CallGraph.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/SCCIterator.h"
 
 using namespace mlir;
 
+/// This function implements the default inliner optimization pipeline.
+static void defaultInlinerOptPipeline(OpPassManager &pm) {
+  pm.addPass(createCanonicalizerPass());
+}
+
 //===----------------------------------------------------------------------===//
 // Symbol Use Tracking
 //===----------------------------------------------------------------------===//
@@ -279,9 +283,9 @@ private:
 
 /// Run a given transformation over the SCCs of the callgraph in a bottom up
 /// traversal.
-static void
-runTransformOnCGSCCs(const CallGraph &cg,
-                     function_ref<void(CallGraphSCC &)> sccTransformer) {
+static LogicalResult runTransformOnCGSCCs(
+    const CallGraph &cg,
+    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
   llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
   CallGraphSCC currentSCC(cgi);
   while (!cgi.isAtEnd()) {
@@ -289,8 +293,10 @@ runTransformOnCGSCCs(const CallGraph &cg,
     // SCC without invalidating our iterator.
     currentSCC.reset(*cgi);
     ++cgi;
-    sccTransformer(currentSCC);
+    if (failed(sccTransformer(currentSCC)))
+      return failure();
   }
+  return success();
 }
 
 namespace {
@@ -499,85 +505,94 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
   return success(inlinedAnyCalls);
 }
 
-/// Canonicalize the nodes within the given SCC with the given set of
-/// canonicalization patterns.
-static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
-                            CallGraphSCC &currentSCC, MLIRContext *context,
-                            const FrozenRewritePatternList &canonPatterns) {
-  // Collect the sets of nodes to canonicalize.
-  SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
-  for (auto *node : currentSCC) {
-    // Don't canonicalize the external node, it has no valid callable region.
-    if (node->isExternal())
-      continue;
-
-    // Don't canonicalize nodes with children. Nodes with children
-    // require special handling as we may remove the node during
-    // canonicalization. In the future, we should be able to handle this
-    // case with proper node deletion tracking.
-    if (node->hasChildren())
-      continue;
-
-    // We also won't apply canonicalizations for nodes that are not
-    // isolated. This avoids potentially mutating the regions of nodes defined
-    // above, this is also a stipulation of the 'applyPatternsAndFoldGreedily'
-    // driver.
-    auto *region = node->getCallableRegion();
-    if (!region->getParentOp()->isKnownIsolatedFromAbove())
-      continue;
-    nodesToCanonicalize.push_back(node);
-  }
-  if (nodesToCanonicalize.empty())
-    return;
-
-  // Canonicalize each of the nodes within the SCC in parallel.
-  // NOTE: This is simple now, because we don't enable canonicalizing nodes
-  // within children. When we remove this restriction, this logic will need to
-  // be reworked.
-  if (context->isMultithreadingEnabled()) {
-    ParallelDiagnosticHandler canonicalizationHandler(context);
-    llvm::parallelForEachN(
-        /*Begin=*/0, /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
-          // Set the order for this thread so that diagnostics will be properly
-          // ordered.
-          canonicalizationHandler.setOrderIDForThread(index);
-
-          // Apply the canonicalization patterns to this region.
-          auto *node = nodesToCanonicalize[index];
-          applyPatternsAndFoldGreedily(*node->getCallableRegion(),
-                                       canonPatterns);
-
-          // Make sure to reset the order ID for the diagnostic handler, as this
-          // thread may be used in a different context.
-          canonicalizationHandler.eraseOrderIDForThread();
-        });
-  } else {
-    for (CallGraphNode *node : nodesToCanonicalize)
-      applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
-  }
-
-  // Recompute the uses held by each of the nodes.
-  for (CallGraphNode *node : nodesToCanonicalize)
-    useList.recomputeUses(node, cg);
-}
-
 //===----------------------------------------------------------------------===//
 // InlinerPass
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct InlinerPass : public InlinerBase<InlinerPass> {
+class InlinerPass : public InlinerBase<InlinerPass> {
+public:
+  InlinerPass();
+  InlinerPass(const InlinerPass &) = default;
+  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
+  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
+              llvm::StringMap<OpPassManager> opPipelines);
   void runOnOperation() override;
 
-  /// Attempt to inline calls within the given scc, and run canonicalizations
-  /// with the given patterns, until a fixed point is reached. This allows for
-  /// the inlining of newly devirtualized calls.
-  void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
-                 MLIRContext *context,
-                 const FrozenRewritePatternList &canonPatterns);
+private:
+  /// Attempt to inline calls within the given scc, and run simplifications,
+  /// until a fixed point is reached. This allows for the inlining of newly
+  /// devirtualized calls. Returns failure if there was a fatal error during
+  /// inlining.
+  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
+                          CallGraphSCC &currentSCC, MLIRContext *context);
+
+  /// Optimize the nodes within the given SCC with one of the held optimization
+  /// pass pipelines. Returns failure if an error occurred during the
+  /// optimization of the SCC, success otherwise.
+  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
+                            CallGraphSCC &currentSCC, MLIRContext *context);
+
+  /// Optimize the nodes within the given SCC in parallel. Returns failure if an
+  /// error occurred during the optimization of the SCC, success otherwise.
+  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
+                                 MLIRContext *context);
+
+  /// Optimize the given callable node with one of the pass managers provided
+  /// with `pipelines`, or the default pipeline. Returns failure if an error
+  /// occurred during the optimization of the callable, success otherwise.
+  LogicalResult optimizeCallable(CallGraphNode *node,
+                                 llvm::StringMap<OpPassManager> &pipelines);
+
+  /// Attempt to initialize the options of this pass from the given string.
+  /// Derived classes may override this method to hook into the point at which
+  /// options are initialized, but should generally always invoke this base
+  /// class variant.
+  LogicalResult initializeOptions(StringRef options) override;
+
+  /// An optional function that constructs a default optimization pipeline for
+  /// a given operation.
+  std::function<void(OpPassManager &)> defaultPipeline;
+  /// A map of operation names to pass pipelines to use when optimizing
+  /// callable operations of these types. This provides a specialized pipeline
+  /// instead of the default. The vector size is the number of threads used
+  /// during optimization.
+  SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
 };
 } // end anonymous namespace
 
+InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
+InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
+    : defaultPipeline(defaultPipeline) {
+  opPipelines.push_back({});
+
+  // Initialize the pass options with the provided arguments.
+  if (defaultPipeline) {
+    OpPassManager fakePM("__mlir_fake_pm_op");
+    defaultPipeline(fakePM);
+    llvm::raw_string_ostream strStream(defaultPipelineStr);
+    fakePM.printAsTextualPipeline(strStream);
+  }
+}
+
+InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
+                         llvm::StringMap<OpPassManager> opPipelines)
+    : InlinerPass(std::move(defaultPipeline)) {
+  if (opPipelines.empty())
+    return;
+
+  // Update the option for the op specific optimization pipelines.
+  for (auto &it : opPipelines) {
+    std::string pipeline;
+    llvm::raw_string_ostream pipelineOS(pipeline);
+    pipelineOS << it.getKey() << "(";
+    it.second.printAsTextualPipeline(pipelineOS);
+    pipelineOS << ")";
+    opPipelineStrs.addValue(pipeline);
+  }
+  this->opPipelines.emplace_back(std::move(opPipelines));
+}
+
 void InlinerPass::runOnOperation() {
   CallGraph &cg = getAnalysis<CallGraph>();
   auto *context = &getContext();
@@ -591,42 +606,190 @@ void InlinerPass::runOnOperation() {
     return signalPassFailure();
   }
 
-  // Collect a set of canonicalization patterns to use when simplifying
-  // callable regions within an SCC.
-  OwningRewritePatternList canonPatterns;
-  for (auto *op : context->getRegisteredOperations())
-    op->getCanonicalizationPatterns(canonPatterns, context);
-  FrozenRewritePatternList frozenCanonPatterns(std::move(canonPatterns));
-
   // Run the inline transform in post-order over the SCCs in the callgraph.
   SymbolTableCollection symbolTable;
   Inliner inliner(context, cg, symbolTable);
   CGUseList useList(getOperation(), cg, symbolTable);
-  runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
-    inlineSCC(inliner, useList, scc, context, frozenCanonPatterns);
+  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
+    return inlineSCC(inliner, useList, scc, context);
   });
+  if (failed(result))
+    return signalPassFailure();
 
   // After inlining, make sure to erase any callables proven to be dead.
   inliner.eraseDeadCallables();
 }
 
-void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
-                            CallGraphSCC &currentSCC, MLIRContext *context,
-                            const FrozenRewritePatternList &canonPatterns) {
-  // If we successfully inlined any calls, run some simplifications on the
-  // nodes of the scc. Continue attempting to inline until we reach a fixed
-  // point, or a maximum iteration count. We canonicalize here as it may
-  // devirtualize new calls, as well as give us a better cost model.
+LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
+                                     CallGraphSCC &currentSCC,
+                                     MLIRContext *context) {
+  // Continuously simplify and inline until we either reach a fixed point, or
+  // hit the maximum iteration count. Simplifying early helps to refine the cost
+  // model, and in future iterations may devirtualize new calls.
   unsigned iterationCount = 0;
-  while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) {
-    // If we aren't allowing simplifications or the max iteration count was
-    // reached, then bail out early.
-    if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
+  do {
+    if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
+      return failure();
+    if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
       break;
-    canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns);
+  } while (++iterationCount < maxInliningIterations);
+  return success();
+}
+
+LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
+                                       CallGraphSCC &currentSCC,
+                                       MLIRContext *context) {
+  // Collect the sets of nodes to simplify.
+  SmallVector<CallGraphNode *, 4> nodesToVisit;
+  for (auto *node : currentSCC) {
+    if (node->isExternal())
+      continue;
+
+    // Don't simplify nodes with children. Nodes with children require special
+    // handling as we may remove the node during simplification. In the future,
+    // we should be able to handle this case with proper node deletion tracking.
+    if (node->hasChildren())
+      continue;
+
+    // We also won't apply simplifications to nodes that can't have passes
+    // scheduled on them.
+    auto *region = node->getCallableRegion();
+    if (!region->getParentOp()->isKnownIsolatedFromAbove())
+      continue;
+    nodesToVisit.push_back(node);
+  }
+  if (nodesToVisit.empty())
+    return success();
+
+  // Optimize each of the nodes within the SCC in parallel.
+  // NOTE: This is simple now, because we don't enable optimizing nodes within
+  // children. When we remove this restriction, this logic will need to be
+  // reworked.
+  if (context->isMultithreadingEnabled()) {
+    if (failed(optimizeSCCAsync(nodesToVisit, context)))
+      return failure();
+
+    // Otherwise, we are optimizing within a single thread.
+  } else {
+    for (CallGraphNode *node : nodesToVisit) {
+      if (failed(optimizeCallable(node, opPipelines[0])))
+        return failure();
+    }
+  }
+
+  // Recompute the uses held by each of the nodes.
+  for (CallGraphNode *node : nodesToVisit)
+    useList.recomputeUses(node, cg);
+  return success();
+}
+
+LogicalResult
+InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
+                              MLIRContext *context) {
+  // Ensure that there are enough pipeline maps for the optimizer to run in
+  // parallel.
+  size_t numThreads = llvm::hardware_concurrency().compute_thread_count();
+  if (opPipelines.size() != numThreads) {
+    // Reserve before resizing so that we can use a reference to the first
+    // element.
+    opPipelines.reserve(numThreads);
+    opPipelines.resize(numThreads, opPipelines.front());
+  }
+
+  // Ensure an analysis manager has been constructed for each of the nodes.
+  // This prevents thread races when running the nested pipelines.
+  for (CallGraphNode *node : nodesToVisit)
+    getAnalysisManager().nest(node->getCallableRegion()->getParentOp());
+
+  // An index for the current node to optimize.
+  std::atomic<unsigned> nodeIt(0);
+
+  // Optimize the nodes of the SCC in parallel.
+  ParallelDiagnosticHandler optimizerHandler(context);
+  return llvm::parallelTransformReduce(
+      llvm::seq<size_t>(0, numThreads), success(),
+      [](LogicalResult lhs, LogicalResult rhs) {
+        return success(succeeded(lhs) && succeeded(rhs));
+      },
+      [&](size_t index) {
+        LogicalResult result = success();
+        for (auto e = nodesToVisit.size(); nodeIt < e && succeeded(result);) {
+          // Get the next available operation index.
+          unsigned nextID = nodeIt++;
+          if (nextID >= e)
+            break;
+
+          // Set the order for this thread so that diagnostics will be
+          // properly ordered, and reset after optimization has finished.
+          optimizerHandler.setOrderIDForThread(nextID);
+          result = optimizeCallable(nodesToVisit[nextID], opPipelines[index]);
+          optimizerHandler.eraseOrderIDForThread();
+        }
+        return result;
+      });
+}
+
+LogicalResult
+InlinerPass::optimizeCallable(CallGraphNode *node,
+                              llvm::StringMap<OpPassManager> &pipelines) {
+  Operation *callable = node->getCallableRegion()->getParentOp();
+  StringRef opName = callable->getName().getStringRef();
+  auto pipelineIt = pipelines.find(opName);
+  if (pipelineIt == pipelines.end()) {
+    // If a pipeline didn't exist, use the default if possible.
+    if (!defaultPipeline)
+      return success();
+
+    OpPassManager defaultPM(opName);
+    defaultPipeline(defaultPM);
+    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
   }
+  return runPipeline(pipelineIt->second, callable);
+}
+
+LogicalResult InlinerPass::initializeOptions(StringRef options) {
+  if (failed(Pass::initializeOptions(options)))
+    return failure();
+
+  // Initialize the default pipeline builder to use the option string.
+  if (!defaultPipelineStr.empty()) {
+    std::string defaultPipelineCopy = defaultPipelineStr;
+    defaultPipeline = [=](OpPassManager &pm) {
+      parsePassPipeline(defaultPipelineCopy, pm);
+    };
+  } else if (defaultPipelineStr.getNumOccurrences()) {
+    defaultPipeline = nullptr;
+  }
+
+  // Initialize the op specific pass pipelines.
+  llvm::StringMap<OpPassManager> pipelines;
+  for (StringRef pipeline : opPipelineStrs) {
+    // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
+    size_t pipelineStart = pipeline.find_first_of('(');
+    if (pipelineStart == StringRef::npos || !pipeline.consume_back(")"))
+      return failure();
+    StringRef opName = pipeline.take_front(pipelineStart);
+    OpPassManager pm(opName);
+    if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
+      return failure();
+    pipelines.try_emplace(opName, std::move(pm));
+  }
+  opPipelines.assign({std::move(pipelines)});
+
+  return success();
 }
 
 std::unique_ptr<Pass> mlir::createInlinerPass() {
   return std::make_unique<InlinerPass>();
 }
+std::unique_ptr<Pass>
+mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
+  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
+                                       std::move(opPipelines));
+}
+std::unique_ptr<Pass>
+createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
+                  std::function<void(OpPassManager &)> defaultPipelineBuilder) {
+  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
+                                       std::move(opPipelines));
+}
index 5879acd..173e48c 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -inline="disable-simplify" | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -inline="default-pipeline=''" | FileCheck %s
 
 // Basic test that functions within affine operations are inlined.
 func @func_with_affine_ops(%N: index) {
index 36b1f8f..983fc26 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline{disable-simplify})' | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline{default-pipeline=''})' | FileCheck %s
 
 spv.module Logical GLSL450 {
   spv.func @callee() "None" {
index 9e0945b..a1ba9cc 100644 (file)
@@ -20,9 +20,9 @@ module @inner_mod1 {
 // CHECK: Dump Before CSE
 // NOTNESTED-NEXT: @inner_mod1
 // NESTED-NEXT: @foo
-  func private @foo()
+  module @foo {}
 // Only in the nested case we have a second run of the pass here.
 // NESTED: Dump Before CSE
 // NESTED-NEXT: @baz
-  func private @baz()
+  module @baz {}
 }
index be9aa9c..d568be0 100644 (file)
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -inline="disable-simplify" | FileCheck %s
-// RUN: mlir-opt %s -inline="disable-simplify" -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC
+// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s
+// RUN: mlir-opt %s -inline='default-pipeline=''' -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC
 // RUN: mlir-opt %s -inline | FileCheck %s --check-prefix INLINE_SIMPLIFY
 
 // Inline a function that takes an argument.
index 57c5a59..a6a83dd 100644 (file)
@@ -35,15 +35,17 @@ public:
   TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}
 
   void runOnOperation() override {
+    Operation *currentOp = getOperation();
+
     llvm::errs() << "Dynamic execute '" << pipeline << "' on "
-                 << getOperation()->getName() << "\n";
+                 << currentOp->getName() << "\n";
     if (pipeline.empty()) {
       llvm::errs() << "Empty pipeline\n";
       return;
     }
-    auto symbolOp = dyn_cast<SymbolOpInterface>(getOperation());
+    auto symbolOp = dyn_cast<SymbolOpInterface>(currentOp);
     if (!symbolOp) {
-      getOperation()->emitWarning()
+      currentOp->emitWarning()
           << "Ignoring because not implementing SymbolOpInterface\n";
       return;
     }
@@ -54,24 +56,24 @@ public:
       return;
     }
     if (!pm) {
-      pm = std::make_unique<OpPassManager>(
-          getOperation()->getName().getIdentifier(),
-          OpPassManager::Nesting::Implicit);
+      pm = std::make_unique<OpPassManager>(currentOp->getName().getIdentifier(),
+                                           OpPassManager::Nesting::Implicit);
       parsePassPipeline(pipeline, *pm, llvm::errs());
     }
 
     // Check that running on the parent operation always immediately fails.
     if (runOnParent) {
-      if (getOperation()->getParentOp())
-        if (!failed(runPipeline(*pm, getOperation()->getParentOp())))
+      if (currentOp->getParentOp())
+        if (!failed(runPipeline(*pm, currentOp->getParentOp())))
           signalPassFailure();
       return;
     }
 
     if (runOnNestedOp) {
       llvm::errs() << "Run on nested op\n";
-      getOperation()->walk([&](Operation *op) {
-        if (op == getOperation() || !op->isKnownIsolatedFromAbove())
+      currentOp->walk([&](Operation *op) {
+        if (op == currentOp || !op->isKnownIsolatedFromAbove() ||
+            op->getName() != currentOp->getName())
           return;
         llvm::errs() << "Run on " << *op << "\n";
         // Run on the current operation
@@ -80,7 +82,7 @@ public:
       });
     } else {
       // Run on the current operation
-      if (failed(runPipeline(*pm, getOperation())))
+      if (failed(runPipeline(*pm, currentOp)))
         signalPassFailure();
     }
   }