[mlir] Drop reliance of SliceAnalysis on specific ops.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 15 Feb 2021 21:44:44 +0000 (21:44 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 16 Feb 2021 06:34:32 +0000 (06:34 +0000)
SliceAnalysis originally was developed in the context of affine.for within mlfunc.
It predates the notion of region.
This revision updates it to not hardcode specific ops like scf::ForOp.
When rooted at an op, the behavior of the slice computation changes as it recurses into the regions of the op. This does not support gathering all values transitively depending on a loop induction variable anymore.
Additional variants rooted at a Value are added to also support the existing behavior.

Differential revision: https://reviews.llvm.org/D96702

mlir/include/mlir/Analysis/SliceAnalysis.h
mlir/lib/Analysis/SliceAnalysis.cpp
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp

index 8d1bff5..f418684 100644 (file)
 namespace mlir {
 
 class Operation;
+class Value;
 
 /// Type of the condition to limit the propagation of transitive use-defs.
 /// This can be used in particular to limit the propagation to a given Scope or
 /// to avoid passing through certain types of operation in a configurable
 /// manner.
-using TransitiveFilter = std::function<bool(Operation *)>;
+using TransitiveFilter = llvm::function_ref<bool(Operation *)>;
 
 /// Fills `forwardSlice` with the computed forward slice (i.e. all
 /// the transitive uses of op), **without** including that operation.
@@ -67,10 +68,13 @@ using TransitiveFilter = std::function<bool(Operation *)>;
 /// 2. reversing the result of 1. gives:
 ///      {4, 3, 6, 2, 1, 5, 8, 7, 9}
 ///
-void getForwardSlice(
-    Operation *op, llvm::SetVector<Operation *> *forwardSlice,
-    TransitiveFilter filter = /* pass-through*/
-    [](Operation *) { return true; });
+void getForwardSlice(Operation *op, llvm::SetVector<Operation *> *forwardSlice,
+                     TransitiveFilter filter = nullptr /* pass-through*/);
+
+/// Value-rooted version of `getForwardSlice`. Return the union of all forward
+/// slices for the uses of the value `root`.
+void getForwardSlice(Value root, llvm::SetVector<Operation *> *forwardSlice,
+                     TransitiveFilter filter = nullptr /* pass-through*/);
 
 /// Fills `backwardSlice` with the computed backward slice (i.e.
 /// all the transitive defs of op), **without** including that operation.
@@ -106,10 +110,14 @@ void getForwardSlice(
 /// Assuming all local orders match the numbering order:
 ///    {1, 2, 5, 3, 4, 6}
 ///
-void getBackwardSlice(
-    Operation *op, llvm::SetVector<Operation *> *backwardSlice,
-    TransitiveFilter filter = /* pass-through*/
-    [](Operation *) { return true; });
+void getBackwardSlice(Operation *op,
+                      llvm::SetVector<Operation *> *backwardSlice,
+                      TransitiveFilter filter = nullptr /* pass-through*/);
+
+/// Value-rooted version of `getBackwardSlice`. Return the union of all backward
+/// slices for the op defining or owning the value `root`.
+void getBackwardSlice(Value root, llvm::SetVector<Operation *> *backwardSlice,
+                      TransitiveFilter filter = nullptr /* pass-through*/);
 
 /// Iteratively computes backward slices and forward slices until
 /// a fixed point is reached. Returns an `llvm::SetVector<Operation *>` which
@@ -188,12 +196,10 @@ void getBackwardSlice(
 /// and keep things ordered but this is still hand-wavy and not worth the
 /// trouble for now: punt to a simple worklist-based solution.
 ///
-llvm::SetVector<Operation *> getSlice(
-    Operation *op,
-    TransitiveFilter backwardFilter = /* pass-through*/
-    [](Operation *) { return true; },
-    TransitiveFilter forwardFilter = /* pass-through*/
-    [](Operation *) { return true; });
+llvm::SetVector<Operation *>
+getSlice(Operation *op,
+         TransitiveFilter backwardFilter = nullptr /* pass-through*/,
+         TransitiveFilter forwardFilter = nullptr /* pass-through*/);
 
 /// Multi-root DAG topological sort.
 /// Performs a topological sort of the Operation in the `toSort` SetVector.
index 07cbca8..47d1258 100644 (file)
@@ -30,36 +30,24 @@ using llvm::SetVector;
 static void getForwardSliceImpl(Operation *op,
                                 SetVector<Operation *> *forwardSlice,
                                 TransitiveFilter filter) {
-  if (!op) {
+  if (!op)
     return;
-  }
 
   // Evaluate whether we should keep this use.
   // This is useful in particular to implement scoping; i.e. return the
   // transitive forwardSlice in the current scope.
-  if (!filter(op)) {
+  if (filter && !filter(op))
     return;
-  }
 
-  if (auto forOp = dyn_cast<AffineForOp>(op)) {
-    for (Operation *userOp : forOp.getInductionVar().getUsers())
+  for (Region &region : op->getRegions())
+    for (Block &block : region)
+      for (Operation &blockOp : block)
+        if (forwardSlice->count(&blockOp) == 0)
+          getForwardSliceImpl(&blockOp, forwardSlice, filter);
+  for (Value result : op->getResults()) {
+    for (Operation *userOp : result.getUsers())
       if (forwardSlice->count(userOp) == 0)
         getForwardSliceImpl(userOp, forwardSlice, filter);
-  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
-    for (Operation *userOp : forOp.getInductionVar().getUsers())
-      if (forwardSlice->count(userOp) == 0)
-        getForwardSliceImpl(userOp, forwardSlice, filter);
-    for (Value result : forOp.getResults())
-      for (Operation *userOp : result.getUsers())
-        if (forwardSlice->count(userOp) == 0)
-          getForwardSliceImpl(userOp, forwardSlice, filter);
-  } else {
-    assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
-    for (Value result : op->getResults()) {
-      for (Operation *userOp : result.getUsers())
-        if (forwardSlice->count(userOp) == 0)
-          getForwardSliceImpl(userOp, forwardSlice, filter);
-    }
   }
 
   forwardSlice->insert(op);
@@ -79,45 +67,47 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
   forwardSlice->insert(v.rbegin(), v.rend());
 }
 
+void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
+                           TransitiveFilter filter) {
+  for (Operation *user : root.getUsers())
+    getForwardSliceImpl(user, forwardSlice, filter);
+
+  // Reverse to get back the actual topological order.
+  // std::reverse does not work out of the box on SetVector and I want an
+  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
+  std::vector<Operation *> v(forwardSlice->takeVector());
+  forwardSlice->insert(v.rbegin(), v.rend());
+}
+
 static void getBackwardSliceImpl(Operation *op,
                                  SetVector<Operation *> *backwardSlice,
                                  TransitiveFilter filter) {
-  if (!op)
+  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
     return;
 
-  assert((op->getNumRegions() == 0 ||
-          isa<AffineForOp, scf::ForOp, linalg::LinalgOp, linalg::PadTensorOp>(
-              op)) &&
-         "unexpected generic op with regions");
-
   // Evaluate whether we should keep this def.
   // This is useful in particular to implement scoping; i.e. return the
-  // transitive forwardSlice in the current scope.
-  if (!filter(op)) {
+  // transitive backwardSlice in the current scope.
+  if (filter && !filter(op))
     return;
-  }
 
   for (auto en : llvm::enumerate(op->getOperands())) {
     auto operand = en.value();
-    if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
-      if (auto affIv = getForInductionVarOwner(operand)) {
-        auto *affOp = affIv.getOperation();
-        if (backwardSlice->count(affOp) == 0)
-          getBackwardSliceImpl(affOp, backwardSlice, filter);
-      } else if (auto loopIv = scf::getForInductionVarOwner(operand)) {
-        auto *loopOp = loopIv.getOperation();
-        if (backwardSlice->count(loopOp) == 0)
-          getBackwardSliceImpl(loopOp, backwardSlice, filter);
-      } else if (blockArg.getOwner() !=
-                 &op->getParentOfType<FuncOp>().getBody().front()) {
-        op->emitError("unsupported CF for operand ") << en.index();
-        llvm_unreachable("Unsupported control flow");
-      }
-      continue;
-    }
-    auto *op = operand.getDefiningOp();
-    if (backwardSlice->count(op) == 0) {
-      getBackwardSliceImpl(op, backwardSlice, filter);
+    if (auto *definingOp = operand.getDefiningOp()) {
+      if (backwardSlice->count(definingOp) == 0)
+        getBackwardSliceImpl(definingOp, backwardSlice, filter);
+    } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
+      Block *block = blockArg.getOwner();
+      Operation *parentOp = block->getParentOp();
+      // TODO: determine whether we want to recurse backward into the other
+      // blocks of parentOp, which are not technically backward unless they flow
+      // into us. For now, just bail.
+      assert(parentOp->getNumRegions() == 1 &&
+             parentOp->getRegion(0).getBlocks().size() == 1);
+      if (backwardSlice->count(parentOp) == 0)
+        getBackwardSliceImpl(parentOp, backwardSlice, filter);
+    } else {
+      llvm_unreachable("No definingOp and not a block argument.");
     }
   }
 
@@ -134,6 +124,16 @@ void mlir::getBackwardSlice(Operation *op,
   backwardSlice->remove(op);
 }
 
+void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
+                            TransitiveFilter filter) {
+  if (Operation *definingOp = root.getDefiningOp()) {
+    getBackwardSlice(definingOp, backwardSlice, filter);
+    return;
+  }
+  Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
+  getBackwardSlice(bbAargOwner, backwardSlice, filter);
+}
+
 SetVector<Operation *> mlir::getSlice(Operation *op,
                                       TransitiveFilter backwardFilter,
                                       TransitiveFilter forwardFilter) {
index f3d98f6..c3bc73a 100644 (file)
@@ -243,7 +243,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
                         << "\n");
 
       llvm::SetVector<Operation *> forwardSlice;
-      getForwardSlice(transferRead, &forwardSlice);
+      getForwardSlice(transferRead.getOperation(), &forwardSlice);
 
       // Look for the last TransferWriteOp in the forwardSlice of
       // `transferRead` that operates on the same memref.
@@ -381,9 +381,10 @@ hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
   // Get the backwards slice from `padTensorOp` that is dominated by the
   // outermost enclosing loop.
   DominanceInfo domInfo(outermostEnclosingForOp);
-  getBackwardSlice(padTensorOp, &backwardSlice, [&](Operation *op) {
-    return domInfo.dominates(outermostEnclosingForOp, op);
-  });
+  getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
+                   [&](Operation *op) {
+                     return domInfo.dominates(outermostEnclosingForOp, op);
+                   });
 
   // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
   if (llvm::any_of(backwardSlice, [](Operation *op) {
index d8bc6e0..a8c32c8 100644 (file)
@@ -1830,9 +1830,9 @@ Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
 // Return failure when any op fails to hoist.
 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
   SetVector<Operation *> forwardSlice;
-  getForwardSlice(outer.getOperation(), &forwardSlice, [&inner](Operation *op) {
-    return op != inner.getOperation();
-  });
+  getForwardSlice(
+      outer.getInductionVar(), &forwardSlice,
+      [&inner](Operation *op) { return op != inner.getOperation(); });
   LogicalResult status = success();
   SmallVector<Operation *, 8> toHoist;
   for (auto &op : outer.getBody()->without_terminator()) {
@@ -1844,8 +1844,8 @@ static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
       status = failure();
       continue;
     }
-    // Skip scf::ForOp, these are not considered a failure.
-    if (op.getNumRegions() > 0)
+    // Skip intermediate scf::ForOp, these are not considered a failure.
+    if (isa<scf::ForOp>(op))
       continue;
     // Skip other ops with regions.
     if (op.getNumRegions() > 0) {