namespace mlir {
class Operation;
+class Value;
/// Type of the condition to limit the propagation of transitive use-defs.
/// This can be used in particular to limit the propagation to a given Scope or
/// to avoid passing through certain types of operation in a configurable
/// manner.
-using TransitiveFilter = std::function<bool(Operation *)>;
+using TransitiveFilter = llvm::function_ref<bool(Operation *)>;
/// Fills `forwardSlice` with the computed forward slice (i.e. all
/// the transitive uses of op), **without** including that operation.
/// 2. reversing the result of 1. gives:
/// {4, 3, 6, 2, 1, 5, 8, 7, 9}
///
-void getForwardSlice(
- Operation *op, llvm::SetVector<Operation *> *forwardSlice,
- TransitiveFilter filter = /* pass-through*/
- [](Operation *) { return true; });
+void getForwardSlice(Operation *op, llvm::SetVector<Operation *> *forwardSlice,
+ TransitiveFilter filter = nullptr /* pass-through*/);
+
+/// Value-rooted version of `getForwardSlice`. Return the union of all forward
+/// slices for the uses of the value `root`.
+void getForwardSlice(Value root, llvm::SetVector<Operation *> *forwardSlice,
+ TransitiveFilter filter = nullptr /* pass-through*/);
/// Fills `backwardSlice` with the computed backward slice (i.e.
/// all the transitive defs of op), **without** including that operation.
/// Assuming all local orders match the numbering order:
/// {1, 2, 5, 3, 4, 6}
///
-void getBackwardSlice(
- Operation *op, llvm::SetVector<Operation *> *backwardSlice,
- TransitiveFilter filter = /* pass-through*/
- [](Operation *) { return true; });
+void getBackwardSlice(Operation *op,
+ llvm::SetVector<Operation *> *backwardSlice,
+ TransitiveFilter filter = nullptr /* pass-through*/);
+
+/// Value-rooted version of `getBackwardSlice`. Return the union of all backward
+/// slices for the op defining or owning the value `root`.
+void getBackwardSlice(Value root, llvm::SetVector<Operation *> *backwardSlice,
+ TransitiveFilter filter = nullptr /* pass-through*/);
/// Iteratively computes backward slices and forward slices until
/// a fixed point is reached. Returns an `llvm::SetVector<Operation *>` which
/// and keep things ordered but this is still hand-wavy and not worth the
/// trouble for now: punt to a simple worklist-based solution.
///
-llvm::SetVector<Operation *> getSlice(
- Operation *op,
- TransitiveFilter backwardFilter = /* pass-through*/
- [](Operation *) { return true; },
- TransitiveFilter forwardFilter = /* pass-through*/
- [](Operation *) { return true; });
+llvm::SetVector<Operation *>
+getSlice(Operation *op,
+ TransitiveFilter backwardFilter = nullptr /* pass-through*/,
+ TransitiveFilter forwardFilter = nullptr /* pass-through*/);
/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operation in the `toSort` SetVector.
static void getForwardSliceImpl(Operation *op,
SetVector<Operation *> *forwardSlice,
TransitiveFilter filter) {
- if (!op) {
+ if (!op)
return;
- }
// Evaluate whether we should keep this use.
// This is useful in particular to implement scoping; i.e. return the
// transitive forwardSlice in the current scope.
- if (!filter(op)) {
+ if (filter && !filter(op))
return;
- }
- if (auto forOp = dyn_cast<AffineForOp>(op)) {
- for (Operation *userOp : forOp.getInductionVar().getUsers())
+ for (Region ®ion : op->getRegions())
+ for (Block &block : region)
+ for (Operation &blockOp : block)
+ if (forwardSlice->count(&blockOp) == 0)
+ getForwardSliceImpl(&blockOp, forwardSlice, filter);
+ for (Value result : op->getResults()) {
+ for (Operation *userOp : result.getUsers())
if (forwardSlice->count(userOp) == 0)
getForwardSliceImpl(userOp, forwardSlice, filter);
- } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
- for (Operation *userOp : forOp.getInductionVar().getUsers())
- if (forwardSlice->count(userOp) == 0)
- getForwardSliceImpl(userOp, forwardSlice, filter);
- for (Value result : forOp.getResults())
- for (Operation *userOp : result.getUsers())
- if (forwardSlice->count(userOp) == 0)
- getForwardSliceImpl(userOp, forwardSlice, filter);
- } else {
- assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
- for (Value result : op->getResults()) {
- for (Operation *userOp : result.getUsers())
- if (forwardSlice->count(userOp) == 0)
- getForwardSliceImpl(userOp, forwardSlice, filter);
- }
}
forwardSlice->insert(op);
forwardSlice->insert(v.rbegin(), v.rend());
}
+void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
+ TransitiveFilter filter) {
+ for (Operation *user : root.getUsers())
+ getForwardSliceImpl(user, forwardSlice, filter);
+
+ // Reverse to get back the actual topological order.
+ // std::reverse does not work out of the box on SetVector and I want an
+ // in-place swap based thing (the real std::reverse, not the LLVM adapter).
+ std::vector<Operation *> v(forwardSlice->takeVector());
+ forwardSlice->insert(v.rbegin(), v.rend());
+}
+
static void getBackwardSliceImpl(Operation *op,
SetVector<Operation *> *backwardSlice,
TransitiveFilter filter) {
- if (!op)
+ if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
return;
- assert((op->getNumRegions() == 0 ||
- isa<AffineForOp, scf::ForOp, linalg::LinalgOp, linalg::PadTensorOp>(
- op)) &&
- "unexpected generic op with regions");
-
// Evaluate whether we should keep this def.
// This is useful in particular to implement scoping; i.e. return the
- // transitive forwardSlice in the current scope.
- if (!filter(op)) {
+ // transitive backwardSlice in the current scope.
+ if (filter && !filter(op))
return;
- }
for (auto en : llvm::enumerate(op->getOperands())) {
auto operand = en.value();
- if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
- if (auto affIv = getForInductionVarOwner(operand)) {
- auto *affOp = affIv.getOperation();
- if (backwardSlice->count(affOp) == 0)
- getBackwardSliceImpl(affOp, backwardSlice, filter);
- } else if (auto loopIv = scf::getForInductionVarOwner(operand)) {
- auto *loopOp = loopIv.getOperation();
- if (backwardSlice->count(loopOp) == 0)
- getBackwardSliceImpl(loopOp, backwardSlice, filter);
- } else if (blockArg.getOwner() !=
- &op->getParentOfType<FuncOp>().getBody().front()) {
- op->emitError("unsupported CF for operand ") << en.index();
- llvm_unreachable("Unsupported control flow");
- }
- continue;
- }
- auto *op = operand.getDefiningOp();
- if (backwardSlice->count(op) == 0) {
- getBackwardSliceImpl(op, backwardSlice, filter);
+ if (auto *definingOp = operand.getDefiningOp()) {
+ if (backwardSlice->count(definingOp) == 0)
+ getBackwardSliceImpl(definingOp, backwardSlice, filter);
+ } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
+ Block *block = blockArg.getOwner();
+ Operation *parentOp = block->getParentOp();
+ // TODO: determine whether we want to recurse backward into the other
+ // blocks of parentOp, which are not technically backward unless they flow
+ // into us. For now, just bail.
+ assert(parentOp->getNumRegions() == 1 &&
+ parentOp->getRegion(0).getBlocks().size() == 1);
+ if (backwardSlice->count(parentOp) == 0)
+ getBackwardSliceImpl(parentOp, backwardSlice, filter);
+ } else {
+ llvm_unreachable("No definingOp and not a block argument.");
}
}
backwardSlice->remove(op);
}
+void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
+ TransitiveFilter filter) {
+ if (Operation *definingOp = root.getDefiningOp()) {
+ getBackwardSlice(definingOp, backwardSlice, filter);
+ return;
+ }
+ Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
+ getBackwardSlice(bbAargOwner, backwardSlice, filter);
+}
+
SetVector<Operation *> mlir::getSlice(Operation *op,
TransitiveFilter backwardFilter,
TransitiveFilter forwardFilter) {