[mlir][tiling] Relax tiling to accept generating multiple operations.
authorHanhan Wang <hanchung@google.com>
Fri, 4 Nov 2022 20:58:59 +0000 (13:58 -0700)
committerHanhan Wang <hanchung@google.com>
Fri, 4 Nov 2022 20:59:24 +0000 (13:59 -0700)
Some operations need to generate multiple operations when implementing
the tiling interface. Here is a sound example in IREE, see
https://github.com/iree-org/iree/pull/10905 for more details.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D137300

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

index 9fa4114..151993c 100644 (file)
@@ -62,8 +62,10 @@ struct SCFTilingOptions {
 
 /// Transformation information returned after tiling.
 struct SCFTilingResult {
-  /// The tiled operation generated.
-  Operation *tiledOp;
+  /// Tiled operations that are generated during tiling. The order does not
+  /// matter except the last op. The replacements are expected to be the results
+  /// of the last op.
+  SmallVector<Operation *> tiledOps;
   /// The `scf.for` operations that iterate over the tiles.
   SmallVector<scf::ForOp> loops;
   /// Values to use as replacements for the untiled op. Is the same size as the
index a35dd14..6b8ca91 100644 (file)
@@ -931,7 +931,7 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
   if (failed(maybeTilingResult))
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
-  results.push_back(maybeTilingResult->tiledOp);
+  results.append(maybeTilingResult->tiledOps);
   return DiagnosedSilenceableFailure(success());
 }
 
@@ -1251,7 +1251,7 @@ transform::TileOp::apply(TransformResults &transformResults,
       rewriter.replaceOp(linalgOp,
                          maybeTilingResult->loops.front()->getResults());
 
-    tiled.push_back(maybeTilingResult->tiledOp);
+    tiled.append(maybeTilingResult->tiledOps);
     for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
       loops[en2.index()].push_back(en2.value());
   }
@@ -1609,7 +1609,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
 
     rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements);
 
-    tiled.push_back(tilingResult->tiledOp);
+    tiled.append(tilingResult->tiledOps);
     for (const auto &en2 : llvm::enumerate(tilingResult->loops))
       loops[en2.index()].push_back(en2.value());
   }
index 0c86bd4..6e59bdb 100644 (file)
@@ -360,11 +360,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
         tilingResult.loops.back().getBody()->getTerminator());
   SmallVector<Operation *> tiledImplementation =
       op.getTiledImplementation(rewriter, offsets, sizes);
-  if (tiledImplementation.size() != 1) {
-    return rewriter.notifyMatchFailure(
-        op, "expected tiled implementation to return a single op");
-  }
-  tilingResult.tiledOp = tiledImplementation[0];
+  tilingResult.tiledOps.append(tiledImplementation);
   if (op->getNumResults() == 0) {
     // nothing more to do.
     return tilingResult;
@@ -396,13 +392,13 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   }
 
   FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
-      rewriter, destinationTensors, tilingResult.tiledOp->getResults(),
+      rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(),
       resultOffsetsList, resultSizesList, tilingResult.loops);
   if (failed(replacementOr))
     return rewriter.notifyMatchFailure(op, "failed to yield replacement");
 
   if (auto dstOp =
-          dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) {
+          dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOps.back())) {
     auto innerMostLoop = tilingResult.loops.back();
     SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
     assert(destinationTensors.size() ==
@@ -554,13 +550,14 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
     if (failed(tilingResult))
       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
-    tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp);
+    for (auto tiledOp : tilingResult->tiledOps)
+      tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
     tileAndFuseResult.loops = std::move(tilingResult->loops);
     for (const auto &result : llvm::enumerate(
              llvm::zip(consumer->getResults(), tilingResult->replacements))) {
       tileAndFuseResult.replacements[std::get<0>(result.value())] =
           std::get<1>(result.value());
-      yieldedValueToResultNumber[tilingResult->tiledOp->getResult(
+      yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
           result.index())] = result.index();
     }
   }
index 31e3c1a..1644179 100644 (file)
@@ -193,7 +193,8 @@ struct TestTileUsingSCFForOp
       rewriter.eraseOp(op);
     }
 
-    filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
+    for (auto tiledOp : tilingResult->tiledOps)
+      filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
     return success();
   }