// CallGraph traversal
//===----------------------------------------------------------------------===//
+namespace {
+/// This class represents a specific callgraph SCC.
+class CallGraphSCC {
+public:
+ CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
+ : parentIterator(parentIterator) {}
+ /// Return a range over the nodes within this SCC.
+ std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
+ std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
+
+ /// Reset the nodes of this SCC with those provided.
+ void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
+
+ /// Remove the given node from this SCC.
+ void remove(CallGraphNode *node) {
+ auto it = llvm::find(nodes, node);
+ if (it != nodes.end()) {
+ nodes.erase(it);
+ parentIterator.ReplaceNode(node, nullptr);
+ }
+ }
+
+private:
+ std::vector<CallGraphNode *> nodes;
+ llvm::scc_iterator<const CallGraph *> &parentIterator;
+};
+} // end anonymous namespace
+
/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
-static void runTransformOnCGSCCs(
- const CallGraph &cg,
- function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) {
- std::vector<CallGraphNode *> currentSCCVec;
- auto cgi = llvm::scc_begin(&cg);
+static void
+runTransformOnCGSCCs(const CallGraph &cg,
+ function_ref<void(CallGraphSCC &)> sccTransformer) {
+ llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
+ CallGraphSCC currentSCC(cgi);
while (!cgi.isAtEnd()) {
// Copy the current SCC and increment so that the transformer can modify the
// SCC without invalidating our iterator.
- currentSCCVec = *cgi;
+ currentSCC.reset(*cgi);
++cgi;
- sccTransformer(currentSCCVec);
+ sccTransformer(currentSCC);
}
}
/*traverseNestedCGNodes=*/true);
}
+ /// Mark the given callgraph node for deletion.
+ void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
+
+ /// This method properly disposes of callables that became dead during
+ /// inlining. This should not be called while iterating over the SCCs.
+ void eraseDeadCallables() {
+ for (CallGraphNode *node : deadNodes)
+ node->getCallableRegion()->getParentOp()->erase();
+ }
+
+ /// The set of callables known to be dead.
+ SmallPtrSet<CallGraphNode *, 8> deadNodes;
+
/// The current set of call instructions to consider for inlining.
SmallVector<ResolvedCall, 8> calls;
return true;
}
-/// Delete the given node and remove it from the current scc and the callgraph.
-static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg,
- MutableArrayRef<CallGraphNode *> currentSCC) {
- // Erase the parent operation and remove it from the various lists.
- node->getCallableRegion()->getParentOp()->erase();
- cg.eraseNode(node);
-
- // Replace this node in the currentSCC with the external node.
- auto it = llvm::find(currentSCC, node);
- if (it != currentSCC.end())
- *it = cg.getExternalNode();
-}
-
/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
-static LogicalResult
-inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
- MutableArrayRef<CallGraphNode *> currentSCC) {
+static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
+ CallGraphSCC ¤tSCC) {
CallGraph &cg = inliner.cg;
auto &calls = inliner.calls;
+ // A set of dead nodes to remove after inlining.
+ SmallVector<CallGraphNode *, 1> deadNodes;
+
// Collect all of the direct calls within the nodes of the current SCC. We
// don't traverse nested callgraph nodes, because they are handled separately
// likely within a different SCC.
if (node->isExternal())
continue;
- // If this node is dead, just delete it now.
+ // Don't collect calls if the node is already dead.
if (useList.isDead(node))
- deleteNode(node, useList, cg, currentSCC);
+ deadNodes.push_back(node);
else
collectCallOps(*node->getCallableRegion(), node, cg, calls,
/*traverseNestedCGNodes=*/false);
}
- if (calls.empty())
- return failure();
-
- // A set of dead nodes to remove after inlining.
- SmallVector<CallGraphNode *, 1> deadNodes;
// Try to inline each of the call operations. Don't cache the end iterator
// here as more calls may be added during inlining.
}
}
- for (CallGraphNode *node : deadNodes)
- deleteNode(node, useList, cg, currentSCC);
+ for (CallGraphNode *node : deadNodes) {
+ currentSCC.remove(node);
+ inliner.markForDeletion(node);
+ }
calls.clear();
return success(inlinedAnyCalls);
}
/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
- MutableArrayRef<CallGraphNode *> currentSCC,
- MLIRContext *context,
+ CallGraphSCC ¤tSCC, MLIRContext *context,
const OwningRewritePatternList &canonPatterns) {
// Collect the sets of nodes to canonicalize.
SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
/// Attempt to inline calls within the given scc, and run canonicalizations
/// with the given patterns, until a fixed point is reached. This allows for
/// the inlining of newly devirtualized calls.
- void inlineSCC(Inliner &inliner, CGUseList &useList,
- MutableArrayRef<CallGraphNode *> currentSCC,
+ void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC ¤tSCC,
MLIRContext *context,
const OwningRewritePatternList &canonPatterns);
};
// Run the inline transform in post-order over the SCCs in the callgraph.
Inliner inliner(context, cg);
CGUseList useList(getOperation(), cg);
- runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) {
+ runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
inlineSCC(inliner, useList, scc, context, canonPatterns);
});
+
+ // After inlining, make sure to erase any callables proven to be dead.
+ inliner.eraseDeadCallables();
}
void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
- MutableArrayRef<CallGraphNode *> currentSCC,
- MLIRContext *context,
+ CallGraphSCC ¤tSCC, MLIRContext *context,
const OwningRewritePatternList &canonPatterns) {
// If we successfully inlined any calls, run some simplifications on the
// nodes of the scc. Continue attempting to inline until we reach a fixed