[mlir][TilingInterface] NFC: Separate out a utility method to perform one step of...
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 6 Dec 2022 06:24:54 +0000 (06:24 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Mon, 16 Jan 2023 05:03:41 +0000 (05:03 +0000)
Differential Revision: https://reviews.llvm.org/D141027

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

index 52cd760..dd0ed44 100644 (file)
@@ -505,6 +505,101 @@ getUntiledProducerFromSliceSource(OpOperand *source,
   return {source->get().dyn_cast<OpResult>(), destinationIterArg};
 }
 
+static std::optional<Operation *>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                           tensor::ExtractSliceOp candidateSliceOp,
+                           MutableArrayRef<scf::ForOp> loops) {
+  // 1. Get the producer of the source (potentially walking through
+  // `iter_args` of nested `scf.for`)
+  auto [fusableProducer, destinationIterArg] =
+      getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
+                                        loops);
+  if (!fusableProducer)
+    return std::nullopt;
+
+  // 2. Generate the tiled implementation of the producer of the source
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(candidateSliceOp);
+  FailureOr<Value> fusedProducerValue =
+      tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
+                                                   fusableProducer);
+  if (failed(fusedProducerValue))
+    return std::nullopt;
+  rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
+
+  // 3. If the slice is for a destination operand, for example,
+  //
+  // ```mlir
+  // %0 = linalg.init
+  // %1 = linalg.fill .. outs(%0 : )
+  // %2 = scf.for .. iter_args(%arg0 = %1) {
+  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
+  //     %4 = tensor.extract_slice %arg1 [..]
+  //     .. = linalg.matmul .. outs(%4 : )
+  //   }
+  // }
+  // ```
+  //
+  // the IR is currently
+  //
+  // ```
+  // %0 = linalg.init
+  // %1 = linalg.fill
+  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
+  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
+  //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
+  //     %5 = linalg.fill .. outs(%4 : )
+  //     .. = linalg.matmul .. outs(%5 : )
+  //   }
+  // }
+  // ```
+  //
+  // The untiled `linalg.fill` is still used as the `init_value` since it
+  // was originally a destination operand of the untiled `linalg.matmul`.
+  // When fusing an operand that is a destination operand.
+  //   - Update the iter_arg of the outer most loop to use the destination
+  //     of the untiled producer.
+  //   - Update the destination of the slice of the tiled producer generated
+  //     to use the same basic block argument as the slice that was used to
+  //     generate inplace the tiled implementation of the producer.
+  // With this the IR will be.
+  //
+  // ```
+  // %0 = linalg.init
+  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
+  //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
+  //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
+  //     %4 = linalg.fill .. outs(%3 : )
+  //     .. = linalg.matmul .. outs(%4 : )
+  //   }
+  // }
+  // ```
+  // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
+  // Update to use that when it does become available.
+  scf::ForOp outerMostLoop = loops.front();
+  Optional<unsigned> iterArgNumber;
+  if (destinationIterArg) {
+    iterArgNumber =
+        outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
+  }
+  if (iterArgNumber) {
+    int64_t resultNumber = fusableProducer.getResultNumber();
+    if (auto dstOp =
+            dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
+      outerMostLoop.setIterArg(iterArgNumber.value(),
+                               dstOp.getTiedOpOperand(fusableProducer)->get());
+    }
+    if (auto dstOp = fusedProducerValue.value()
+                         .getDefiningOp<DestinationStyleOpInterface>()) {
+      scf::ForOp innerMostLoop = loops.back();
+      updateDestinationOperandsForTiledOp(
+          rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
+          innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
+    }
+  }
+  return fusedProducerValue->getDefiningOp();
+}
+
 /// Implementation of tile consumer and fuse producer greedily.
 FailureOr<scf::SCFTileAndFuseResult>
 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
@@ -559,105 +654,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
   addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
   OpBuilder::InsertionGuard g(rewriter);
   while (!candidates.empty()) {
-    // 2a. Traverse the slices in BFS fashion.
+    // Traverse the slices in BFS fashion.
     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
     candidates.pop_front();
 
-    // 2b. Get the producer of the source (potentially walking through
-    // `iter_args` of nested `scf.for`)
-    auto [fusableProducer, destinationIterArg] =
-        getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
-                                          tileAndFuseResult.loops);
-    if (!fusableProducer)
+    // The operands of the fused producer might themselved be slices of
+    // values produced by operations that implement the `TilingInterface`.
+    // Add these operations to the worklist.
+    Optional<Operation *> fusedProducer = tileAndFuseProducerOfSlice(
+        rewriter, candidateSliceOp, tileAndFuseResult.loops);
+    if (!fusedProducer)
       continue;
 
-    // 2c. Generate the tiled implementation of the producer of the source
-    rewriter.setInsertionPoint(candidateSliceOp);
-    FailureOr<Value> fusedProducerValue =
-        tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
-                                                     fusableProducer);
-    if (failed(fusedProducerValue))
-      continue;
-    rewriter.replaceOp(candidateSliceOp, *fusedProducerValue);
-
-    // 2d. The operands of the fused producer might themselved be slices of
-    //     values produced by operations that implement the `TilingInterface`.
-    //     Add these operations to the worklist.
-    Operation *fusedProducer = fusedProducerValue->getDefiningOp();
-    tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer);
-    addCandidateSlices(fusedProducer, candidates);
-
-    // 2e. If the slice is for a destination operand, for example,
-    //
-    // ```mlir
-    // %0 = linalg.init
-    // %1 = linalg.fill .. outs(%0 : )
-    // %2 = scf.for .. iter_args(%arg0 = %1) {
-    //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
-    //     %4 = tensor.extract_slice %arg1 [..]
-    //     .. = linalg.matmul .. outs(%4 : )
-    //   }
-    // }
-    // ```
-    //
-    // the IR is currently
-    //
-    // ```
-    // %0 = linalg.init
-    // %1 = linalg.fill
-    // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
-    //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
-    //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
-    //     %5 = linalg.fill .. outs(%4 : )
-    //     .. = linalg.matmul .. outs(%5 : )
-    //   }
-    // }
-    // ```
-    //
-    // The untiled `linalg.fill` is still used as the `init_value` since it
-    // was originally a destination operand of the untiled `linalg.matmul`.
-    // When fusing an operand that is a destination operand.
-    //   - Update the iter_arg of the outer most loop to use the destination
-    //     of the untiled producer.
-    //   - Update the destination of the slice of the tiled producer generated
-    //     to use the same basic block argument as the slice that was used to
-    //     generate inplace the tiled implementation of the producer.
-    // With this the IR will be.
-    //
-    // ```
-    // %0 = linalg.init
-    // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
-    //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
-    //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
-    //     %4 = linalg.fill .. outs(%3 : )
-    //     .. = linalg.matmul .. outs(%4 : )
-    //   }
-    // }
-    // ```
-    // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
-    // Update to use that when it does become available.
-    scf::ForOp outerMostLoop = tileAndFuseResult.loops.front();
-    std::optional<unsigned> iterArgNumber;
-    if (destinationIterArg) {
-      iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand(
-          *destinationIterArg.value());
-    }
-    if (iterArgNumber) {
-      int64_t resultNumber = fusableProducer.getResultNumber();
-      if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(
-              fusableProducer.getOwner())) {
-        outerMostLoop.setIterArg(
-            iterArgNumber.value(),
-            dstOp.getTiedOpOperand(fusableProducer)->get());
-      }
-      if (auto dstOp = fusedProducerValue
-                           ->getDefiningOp<DestinationStyleOpInterface>()) {
-        scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
-        updateDestinationOperandsForTiledOp(
-            rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
-            innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
-      }
-    }
+    tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer.value());
+    addCandidateSlices(fusedProducer.value(), candidates);
   }
   return tileAndFuseResult;
 }