Changes to `SCFFuseProducerOfSliceResult` to also return the operations created durin...
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 20 Mar 2023 18:58:39 +0000 (18:58 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Mon, 20 Mar 2023 20:55:48 +0000 (20:55 +0000)
This is follow up to https://reviews.llvm.org/D145133 that allows
propogating information about ops that are fused back to the caller.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D146254

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

index 5e03ecc..e7bcd06 100644 (file)
@@ -96,6 +96,7 @@ struct SCFTileAndFuseOptions {
 struct SCFFuseProducerOfSliceResult {
   OpResult origProducer;       // Original untiled producer.
   Value tiledAndFusedProducer; // Tile and fused producer value.
+  SmallVector<Operation *> tiledOps;
 };
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
index 6706f54..ec116df 100644 (file)
@@ -604,7 +604,8 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
     }
   }
   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
-                                           tileAndFuseResult->tiledValues[0]};
+                                           tileAndFuseResult->tiledValues[0],
+                                           tileAndFuseResult->tiledOps};
 }
 
 /// Reconstruct the fused producer from within the tiled-and-fused code.
@@ -612,7 +613,8 @@ void mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
     MutableArrayRef<scf::ForOp> loops) {
-  auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
+  auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
+      fusedProducerInfo;
   SmallVector<Value> initValues;
   FailureOr<Value> initValue = tensor::getOrCreateDestination(
       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
@@ -623,8 +625,11 @@ void mlir::scf::yieldReplacementForFusedProducer(
         yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
                          resultOffsets, resultSizes, loops);
   }
-  if (auto dstStyleProducer =
-          fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) {
+  for (auto tileAndFusedOp : tileAndFusedOps) {
+    auto dstStyleProducer =
+        dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
+    if (!dstStyleProducer)
+      continue;
     Value dstValue =
         dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
             ->get();