[mlir][TilingInterface] NFC: Consolidate yield handling.
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 6 Dec 2022 06:07:05 +0000 (06:07 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Mon, 16 Jan 2023 05:03:41 +0000 (05:03 +0000)
Add a new utility method to yield the tiled value as well as
preserving destination passing style.

Differential Revision: https://reviews.llvm.org/D139392

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

index 02323c5..52cd760 100644 (file)
@@ -173,7 +173,7 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
 /// }
 /// ```
 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
-static FailureOr<SmallVector<Value>>
+static SmallVector<Value>
 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
                  ValueRange yieldedValues,
                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
@@ -245,6 +245,27 @@ updateDestinationOperandsForTiledOp(OpBuilder &builder,
   }
 }
 
+/// Helper method to yield the values of the tiled op, as well as
+/// update the destination operands of the tiled op, if it is
+/// a destination passing style op.
+static SmallVector<Value>
+yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
+                 Operation *tiledOp,
+                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
+                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
+                 MutableArrayRef<scf::ForOp> loops) {
+  SmallVector<Value> replacements =
+      yieldTiledValues(rewriter, initValues, tiledOp->getResults(),
+                       tileOffsetsList, tileSizesList, loops);
+  if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
+    auto innerMostLoop = loops.back();
+    SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
+    updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
+                                        innerMostLoop.getRegionIterArgs());
+  }
+  return replacements;
+}
+
 /// Implementation of tiling transformation of `op` that implements the
 /// `TilingInterface` using `scf.for` to iterate over the tiles.
 FailureOr<scf::SCFTilingResult>
@@ -258,12 +279,6 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
         op, "missing tile size computation function");
   }
 
-  // Get destination tensors.
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors)))
-    return rewriter.notifyMatchFailure(op, "failed to get destinations");
-
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
   size_t numLoops = iterationDomain.size();
@@ -362,24 +377,14 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
     }
   }
 
-  FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
-      rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(),
-      resultOffsetsList, resultSizesList, tilingResult.loops);
-  if (failed(replacementOr))
-    return rewriter.notifyMatchFailure(op, "failed to yield replacement");
-
-  if (auto dstOp =
-          dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOps.back())) {
-    auto innerMostLoop = tilingResult.loops.back();
-    SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
-    assert(destinationTensors.size() ==
-               innerMostLoop.getRegionIterArgs().size() &&
-           "unexpected number of outputs");
-    updateDestinationOperandsForTiledOp(rewriter, destinationTensors,
-                                        innerMostLoop.getRegionIterArgs());
-  }
+  SmallVector<Value> destinationTensors;
+  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
+                                             destinationTensors)))
+    return rewriter.notifyMatchFailure(op, "failed to get destinations");
 
-  tilingResult.replacements = *replacementOr;
+  tilingResult.replacements = yieldTiledValues(
+      rewriter, destinationTensors, tilingResult.tiledOps.back(),
+      resultOffsetsList, resultSizesList, tilingResult.loops);
 
   LLVM_DEBUG({
     if (!tilingResult.loops.empty()) {
@@ -449,11 +454,9 @@ mlir::scf::tileReductionUsingScf(PatternRewriter &b,
     resultSizesList.push_back(
         b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-  FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
+  SmallVector<Value> replacements = yieldTiledValues(
       b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
       resultSizesList, loops);
-  if (failed(replacementOr))
-    return b.notifyMatchFailure(op, "failed to yield replacement");
 
   auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
   auto innerMostLoop = loops.back();
@@ -466,7 +469,7 @@ mlir::scf::tileReductionUsingScf(PatternRewriter &b,
 
   // 4. Apply the merge reduction to combine all the partial values.
   b.setInsertionPointAfter(*loops.begin());
-  Operation *mergeOp = op.mergeReductions(b, loc, *replacementOr, reductionDim);
+  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
   b.replaceOp(op, mergeOp->getResults());
 
   SCFReductionTilingResult results;