[mlir][Linalg] Apply fixes to TileReductionUsingForeachThreadOp
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 9 Dec 2022 12:02:28 +0000 (04:02 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 9 Dec 2022 15:51:12 +0000 (07:51 -0800)
In the process, numerous insertion point issues were found and fixed.
RAII on insertion points is now used more dilligently.

Differential Revision: https://reviews.llvm.org/D139714

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

index f7b0c03..f2b3fb7 100644 (file)
@@ -796,7 +796,7 @@ def TileReductionUsingForeachThreadOp :
         scf.foreach_thread.perform_concurrently {
           tensor.parallel_insert_slice %7 into %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
         }
-      } {thread_dim_mapping = []}
+      } {mapping = []}
       %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<?x5xf32>) outs(%arg1 : tensor<?xf32>) {
       ^bb0(%in: f32, %out: f32):
         %4 = arith.addf %in, %out : f32
@@ -807,7 +807,8 @@ def TileReductionUsingForeachThreadOp :
 
   let arguments = (ins PDL_Operation:$target,
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+                   OptionalAttr<DeviceMappingArrayAttr>:$mapping);
   let results = (outs PDL_Operation:$fill_op,
                       PDL_Operation:$split_linalg_op,
                       PDL_Operation:$combining_linalg_op);
index 9e94f10..8fdd6cb 100644 (file)
@@ -1222,7 +1222,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
   FailureOr<linalg::ForeachThreadReductionTilingResult> result =
       linalg::tileReductionUsingForeachThread(
           rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
-          numThreads, tileSizes, /*mapping=*/std::nullopt);
+          numThreads, tileSizes, getMapping());
 
   if (failed(result)) {
     results.assign(3, nullptr);
index 8c34c42..f5cbd81 100644 (file)
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include <utility>
 
@@ -221,6 +224,9 @@ static void calculateTileOffsetsAndSizes(
     Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
     SmallVector<OpFoldResult> &tiledOffsets,
     SmallVector<OpFoldResult> &tiledSizes) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+
   ValueRange threadIds = foreachThreadOp.getThreadIndices();
   SmallVector<OpFoldResult> nonZeroNumThreads =
       llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
@@ -300,6 +306,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
     Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
   Location loc = op->getLoc();
   OpBuilder::InsertionGuard g(b);
+
   SmallVector<Range> loopRanges = op.getIterationDomain(b);
   if (loopRanges.empty())
     return op->emitOpError("expected non-empty loop ranges");
@@ -323,54 +330,64 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
 
   Operation *tiledOp = nullptr;
 
-  // Create the ForeachThreadOp. We don't use the lambda body-builder
+  // 1. Create the ForeachThreadOp. We don't use the lambda body-builder
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
       loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
 
-  // Fill out the ForeachThreadOp body.
-  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  // 2. Fill out the ForeachThreadOp body.
   SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
   calculateTileOffsetsAndSizes(b, loc, foreachThreadOp, numThreads, loopRanges,
                                omitTileOffsetBoundsCheck, nominalTileSizes,
                                tiledOffsets, tiledSizes);
 
-  // Clone the tileable op and update its destination operands to use the output
-  // bbArgs of the ForeachThreadOp.
+  // 3. Clone the tileable op and update its destination operands to use the
+  // output bbArgs of the ForeachThreadOp.
   ArrayRef<BlockArgument> destBbArgs =
       foreachThreadOp.getOutputBlockArguments();
-  Operation *clonedOp = b.clone(*op.getOperation());
-  auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
-  if (destinationStyleOp) {
-    for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
-      auto *it = llvm::find(dest, outOperand->get());
-      assert(it != dest.end() && "dest operand not found in dest");
-      unsigned destNum = std::distance(dest.begin(), it);
-      outOperand->set(destBbArgs[destNum]);
+  {
+    // 3.a. RAII guard, inserting within foreachThreadOp, before terminator.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(foreachThreadOp.getTerminator());
+    Operation *clonedOp = b.clone(*op.getOperation());
+    auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
+    if (destinationStyleOp) {
+      for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
+        auto *it = llvm::find(dest, outOperand->get());
+        assert(it != dest.end() && "dest operand not found in dest");
+        unsigned destNum = std::distance(dest.begin(), it);
+        outOperand->set(destBbArgs[destNum]);
+      }
     }
-  }
 
-  // Tile the cloned op and delete the clone.
-  SmallVector<Operation *> tiledOps =
-      cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
-                                                             tiledSizes);
-  b.eraseOp(clonedOp);
-  assert(tiledOps.size() == 1 && "expected a single produced tiled op");
-  tiledOp = tiledOps.front();
+    // 4. Tile the cloned op and delete the clone.
+    SmallVector<Operation *> tiledOps =
+        cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
+                                                               tiledSizes);
+    b.eraseOp(clonedOp);
+    assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+    tiledOp = tiledOps.front();
+  }
 
+  // 5. Parallel insert back into the result tensor.
   auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
   assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
-  OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
   for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
                            tilingInterfaceOp->getResults(), destBbArgs)) {
-    b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+    // 5.a. Partial subset information is inserted just before the terminator.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(foreachThreadOp.getTerminator());
+
     SmallVector<OpFoldResult> resultOffsets, resultSizes;
     if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
                                         tiledSizes, resultOffsets,
                                         resultSizes)))
       return op->emitOpError("output offsets couldn't be calculated");
     SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
+
+    // 5.b. Parallel insertions are inserted at the end of the combining
+    // terminator.
     b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
     b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
                                             std::get<2>(it), resultOffsets,
@@ -415,6 +432,8 @@ template <typename LoopTy>
 static FailureOr<TiledLinalgOp>
 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
                  const LinalgTilingOptions &options) {
+  OpBuilder::InsertionGuard g(b);
+
   auto nLoops = op.getNumLoops();
   // Initial tile sizes may be too big, only take the first nLoops.
   tileSizes = tileSizes.take_front(nLoops);
@@ -570,17 +589,35 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
                                         Optional<ArrayAttr> mapping) {
   Location loc = op.getLoc();
   OpBuilder::InsertionGuard g(b);
+
   // Ops implementing PartialReductionOpInterface are expected to implement
   // TilingInterface.
+  // TODO: proper core mechanism to tie interfaces together.
   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+
+  // Ops implementing PartialReductionOpInterface are not necessarily expected
+  // to implement TilingInterface.. This cast is unsafe atm.
+  // TODO: proper core mechanism to tie interfaces together.
+  // TODO: this function requires a pair of interfaces ..
+  auto destinationStyleOp =
+      dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+  if (!destinationStyleOp)
+    return b.notifyMatchFailure(op, "not a destination style op");
+
+  // Actually this only work for Linalg ops atm.
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
+  if (!linalgOp)
+    return b.notifyMatchFailure(op, "not a linalg op");
+
   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
   if (op->getNumResults() != 1)
     return b.notifyMatchFailure(
         op, "don't support ops with multiple results for now");
+
   SmallVector<utils::IteratorType> iterators =
       tilingInterfaceOp.getLoopIteratorTypes();
   SmallVector<unsigned> redDims;
-  cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
+  linalgOp.getReductionDims(redDims);
   if (redDims.size() != 1)
     return b.notifyMatchFailure(
         op, "only support ops with one reduction dimension.");
@@ -588,7 +625,8 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
     return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
                                     "many elements as number of threads");
   int reductionDim = static_cast<int>(redDims.front());
-  // 1. create the inital tensor value.
+
+  // 1. Create the inital tensor value.
   FailureOr<Operation *> identityTensor =
       op.generateInitialTensorForPartialReduction(b, loc, numThreads,
                                                   reductionDim);
@@ -615,8 +653,8 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
       loc, identityTensor.value()->getResults(),
       ValueRange(materializedNonZeroNumThreads), mapping);
 
-  // 3. calculate the tile offsets and sizes.
-  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  // 3. Calculate the tile offsets and sizes for the subsequent loop that will
+  // be nested under `foreachThreadOp`.
   SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
   calculateTileOffsetsAndSizes(
       b, loc, foreachThreadOp, numThreads, iterationDomain,
@@ -625,54 +663,77 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
 
   // 4. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForeachThreadOp.
+  ValueRange tilingResults;
   ArrayRef<BlockArgument> destBbArgs =
       foreachThreadOp.getOutputBlockArguments();
-  Operation *clonedOp = b.clone(*op.getOperation());
-  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
-  auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
-  for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
-    auto *it = llvm::find(dest, initOperand->get());
-    assert(it != dest.end() && "dest operand not found in dest");
-    unsigned destNum = std::distance(dest.begin(), it);
-    SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
-    SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
-    SmallVector<OpFoldResult> sizes = tiledSizes;
-    sizes[reductionDim] = b.getIndexAttr(1);
-    outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
-    // TODO: use SubsetExtractOpInterface once it is available.
-    Value patial = b.create<tensor::ExtractSliceOp>(
-        loc, initOperand->get().getType().cast<RankedTensorType>(),
-        destBbArgs[destNum], outOffsets, sizes, strides);
-    initOperand->set(patial);
-  }
-  b.setInsertionPoint(clonedOp);
+  {
+    // 4.a. RAII guard, inserting within foreachThreadOp, before terminator.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(foreachThreadOp.getTerminator());
+
+    SmallVector<Value> tiledDpsInitOperands;
+    for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
+      auto *it = llvm::find(dest, initOperand->get());
+      assert(it != dest.end() && "dest operand not found in dest");
+      unsigned destNum = std::distance(dest.begin(), it);
+      SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
+      SmallVector<OpFoldResult> outOffsets(numThreads.size(),
+                                           b.getIndexAttr(0));
+      SmallVector<OpFoldResult> sizes = tiledSizes;
+      sizes[reductionDim] = b.getIndexAttr(1);
+      outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+      // TODO: use SubsetExtractOpInterface once it is available.
+      tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
+          loc, initOperand->get().getType().cast<RankedTensorType>(),
+          destBbArgs[destNum], outOffsets, sizes, strides));
+    }
 
-  // 5. Tile the cloned op and delete the clone.
-  if (tileSizes.empty()) {
-    SmallVector<Operation *> tiledOps =
-        cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
-                                                               tiledSizes);
-    assert(tiledOps.size() == 1 && "expected a single produced tiled op");
-    tiledOp = tiledOps.front();
-  } else {
-    LinalgTilingOptions options;
-    auto tiled = tileLinalgOpImpl<scf::ForOp>(b, cast<LinalgOp>(clonedOp),
-                                              tileSizes, options);
-    SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
-    mapLoopToProcessorIds(cast<scf::ForOp>(tiled->loops.back()), ids,
-                          materializedNonZeroNumThreads);
-    assert(tiled->loops.size() == 1 && "expected a single produced loop");
-    tiledOp = tiled->loops.front();
+    // 4.b. Clone the op and update init operands.
+    // We cannot use a BlockAndValueMapping here because it can replace
+    // different OpOperands with the same value.
+    Operation *clonedOp = b.clone(*op.getOperation());
+    b.updateRootInPlace(clonedOp, [&]() {
+      for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
+               cast<DestinationStyleOpInterface>(clonedOp).getDpsInitOperands(),
+               tiledDpsInitOperands)) {
+        initOperandPtr->set(tiledInitValue);
+      }
+    });
+
+    // 5. Tile the cloned op and delete the clone.
+    if (tileSizes.empty()) {
+      SmallVector<Operation *> tiledOps =
+          cast<TilingInterface>(clonedOp).getTiledImplementation(
+              b, tiledOffsets, tiledSizes);
+      assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+      tiledOp = tiledOps.front();
+      tilingResults = tiledOp->getResults();
+    } else {
+      LinalgTilingOptions options;
+      FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
+          b, cast<LinalgOp>(clonedOp), tileSizes, options);
+      if (failed(maybeTiled))
+        return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
+
+      SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
+      mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
+                            materializedNonZeroNumThreads);
+      assert(maybeTiled->loops.size() == 1 &&
+             "expected a single produced loop");
+      tiledOp = maybeTiled->op;
+      tilingResults = maybeTiled->loops.front()->getResults();
+    }
+
+    b.eraseOp(clonedOp);
   }
-  b.eraseOp(clonedOp);
 
   // 6. Insert the partial reductions back into a new tensor.
-  b.setInsertionPointAfter(tiledOp);
-  OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
-  for (auto [index, result, bbArg] :
-       llvm::zip(llvm::seq<unsigned>(0, dest.size()), tiledOp->getResults(),
-                 destBbArgs)) {
-    b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+  for (auto [index, result, bbArg] : llvm::zip(
+           llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
+    // 6.a. Partial subset information is inserted just before the terminator.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(foreachThreadOp.getTerminator());
+
     SmallVector<OpFoldResult> resultOffsets, resultSizes;
     if (failed(tilingInterfaceOp.getResultTilePosition(
             b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
@@ -689,18 +750,23 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
       resultOffsetsRank.push_back(resultOffsets[offIdx++]);
       resultSizesRank.push_back(resultSizes[sizeIdx++]);
     }
-
     SmallVector<OpFoldResult> strides(resultSizesRank.size(),
                                       b.getIndexAttr(1));
+
+    // 6.b. Parallel insertions are inserted at the end of the combining
+    // terminator.
     b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
     b.create<tensor::ParallelInsertSliceOp>(
         loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
   }
+
   // 7. Merge the partial reductions.
   b.setInsertionPointAfter(foreachThreadOp);
   Operation *mergeOp =
       op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
   b.replaceOp(op, mergeOp->getResults());
+
+  // 8. Return.
   ForeachThreadReductionTilingResult results;
   results.initialOp = identityTensor.value();
   results.loops = foreachThreadOp;
index f7f0621..76f6485 100644 (file)
@@ -874,19 +874,19 @@ void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
 DiagnosedSilenceableFailure
 transform::PrintOp::apply(transform::TransformResults &results,
                           transform::TransformState &state) {
-  llvm::errs() << "[[[ IR printer: ";
+  llvm::outs() << "[[[ IR printer: ";
   if (getName().has_value())
-    llvm::errs() << *getName() << " ";
+    llvm::outs() << *getName() << " ";
 
   if (!getTarget()) {
-    llvm::errs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
+    llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
     return DiagnosedSilenceableFailure::success();
   }
 
-  llvm::errs() << "]]]\n";
+  llvm::outs() << "]]]\n";
   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
   for (Operation *target : targets)
-    llvm::errs() << *target << "\n";
+    llvm::outs() << *target << "\n";
 
   return DiagnosedSilenceableFailure::success();
 }
index ad2dc0a..cd0d6d7 100644 (file)
@@ -218,7 +218,8 @@ func.func @reduction_tile_parallel_cyclic_dist(
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5], tile_sizes = [0, 3] }
+  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 
+    { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
@@ -262,3 +263,39 @@ transform.sequence failures(propagate) {
 //     CHECK:     linalg.yield
 //     CHECK:   } -> tensor<?xf32>
 //     CHECK:   return %[[R]] : tensor<?xf32>
+
+// -----
+
+func.func @reduction_tile_parallel_cyclic_dist(
+  %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+   iterator_types = ["parallel", "reduction"]}
+   ins(%arg0 : tensor<?x?xf32>)
+   outs(%out : tensor<?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+      %1 = arith.mulf %arg7, %arg7 : f32
+      %2 = arith.addf %1, %arg9 : f32
+      linalg.yield %2 : f32
+    } -> tensor<?xf32>
+  return %red : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 
+    { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
+  
+  //      CHECK:     expecting fill
+  // CHECK-NEXT:     linalg.fill
+  transform.print %1 {name = "expecting fill"} : !pdl.operation
+  //      CHECK:     expecting parallel reduction
+  // CHECK-NEXT:     linalg.generic
+  //      CHECK:     iterator_types = ["parallel", "reduction"]
+  transform.print %2 {name = "expecting parallel reduction"} : !pdl.operation
+  //      CHECK:     expecting parallel reduction
+  // CHECK-NEXT:     linalg.generic
+  //      CHECK:     iterator_types = ["parallel", "reduction"]
+  transform.print %3 {name = "expecting parallel reduction"} : !pdl.operation
+}