[mlir] Add a pattern to fold single- and zero-iteration scf.forall ops.
authorAlexander Belyaev <pifon@google.com>
Tue, 21 Mar 2023 10:56:08 +0000 (11:56 +0100)
committerAlexander Belyaev <pifon@google.com>
Tue, 21 Mar 2023 10:59:25 +0000 (11:59 +0100)
Differential Revision: https://reviews.llvm.org/D145368

mlir/include/mlir/Dialect/SCF/IR/SCF.h
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir

index 7f714d0..cb399b7 100644 (file)
@@ -62,6 +62,14 @@ ForallOp getForallOpThreadIndexOwner(Value val);
 // TODO: Consider moving this functionality to RegionBranchOpInterface.
 bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);
 
+/// Promotes the loop body of a scf::ForallOp to its containing block if the
+/// loop was known to have a single iteration.
+LogicalResult promoteIfSingleIteration(PatternRewriter &rewriter,
+                                       scf::ForallOp forallOp);
+
+/// Promotes the loop body of a scf::ForallOp to its containing block.
+void promote(PatternRewriter &rewriter, scf::ForallOp forallOp);
+
 /// An owning vector of values, handy to return from functions.
 using ValueVector = SmallVector<Value>;
 using LoopVector = SmallVector<scf::ForOp>;
index 27c2775..47910e2 100644 (file)
@@ -128,6 +128,11 @@ SmallVector<int64_t>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
 
+/// Return the number of iterations for a loop with a lower bound `lb`, upper
+/// bound `ub` and step `step`.
+std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
+                                         OpFoldResult step);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
index 4e7bcc4..e212159 100644 (file)
@@ -534,6 +534,61 @@ void ForOp::getSuccessorRegions(std::optional<unsigned> index,
   regions.push_back(RegionSuccessor(getResults()));
 }
 
+/// Promotes the loop body of a forallOp to its containing block if it can be
+/// determined that the loop has a single iteration.
+LogicalResult mlir::scf::promoteIfSingleIteration(PatternRewriter &rewriter,
+                                                  scf::ForallOp forallOp) {
+  for (auto [lb, ub, step] :
+       llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+                 forallOp.getMixedStep())) {
+    auto tripCount = constantTripCount(lb, ub, step);
+    if (!tripCount.has_value() || *tripCount != 1)
+      return failure();
+  }
+
+  promote(rewriter, forallOp);
+  return success();
+}
+
+/// Promotes the loop body of a scf::ForallOp to its containing block.
+void mlir::scf::promote(PatternRewriter &rewriter, scf::ForallOp forallOp) {
+  IRMapping mapping;
+  mapping.map(forallOp.getInductionVars(), forallOp.getLowerBound(rewriter));
+  mapping.map(forallOp.getOutputBlockArguments(), forallOp.getOutputs());
+  for (auto &bodyOp : forallOp.getBody()->without_terminator())
+    rewriter.clone(bodyOp, mapping);
+
+  SmallVector<Value> results;
+  results.reserve(forallOp.getResults().size());
+  scf::InParallelOp terminator = forallOp.getTerminator();
+  for (auto &yieldingOp : terminator.getYieldingOps()) {
+    auto parallelInsertSliceOp =
+        cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+
+    Value dst = parallelInsertSliceOp.getDest();
+    Value src = parallelInsertSliceOp.getSource();
+
+    auto getMappedValues = [&](ValueRange values) {
+      return llvm::to_vector(llvm::map_range(
+          values, [&](Value value) { return mapping.lookupOrDefault(value); }));
+    };
+
+    Value srcVal = mapping.lookupOrDefault(src);
+    if (srcVal.getType().isa<TensorType>()) {
+      results.push_back(rewriter.create<tensor::InsertSliceOp>(
+          forallOp.getLoc(), dst.getType(), srcVal,
+          mapping.lookupOrDefault(dst),
+          getMappedValues(parallelInsertSliceOp.getOffsets()),
+          getMappedValues(parallelInsertSliceOp.getSizes()),
+          getMappedValues(parallelInsertSliceOp.getStrides()),
+          parallelInsertSliceOp.getStaticOffsets(),
+          parallelInsertSliceOp.getStaticSizes(),
+          parallelInsertSliceOp.getStaticStrides()));
+    }
+  }
+  rewriter.replaceOp(forallOp, results);
+}
+
 LoopNest mlir::scf::buildLoopNest(
     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
     ValueRange steps, ValueRange iterArgs,
@@ -1452,16 +1507,99 @@ public:
       dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
       op.getDynamicStepMutable().assign(dynamicStep);
       op.setStaticStep(staticStep);
+
+      op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
+                  rewriter.getDenseI32ArrayAttr(
+                      {static_cast<int32_t>(dynamicLowerBound.size()),
+                       static_cast<int32_t>(dynamicUpperBound.size()),
+                       static_cast<int32_t>(dynamicStep.size()),
+                       static_cast<int32_t>(op.getNumResults())}));
     });
     return success();
   }
 };
 
+struct ForallOpSingleOrZeroIterationDimsFolder
+    : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForallOp op,
+                                PatternRewriter &rewriter) const override {
+    // Do not fold dimensions if they are mapped to processing units.
+    if (op.getMapping().has_value())
+      return failure();
+    Location loc = op.getLoc();
+
+    // Compute new loop bounds that omit all single-iteration loop dimensions.
+    SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
+        newMixedSteps;
+    IRMapping mapping;
+    for (auto [lb, ub, step, iv] :
+         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
+                   op.getMixedStep(), op.getInductionVars())) {
+      auto numIterations = constantTripCount(lb, ub, step);
+      if (numIterations.has_value()) {
+        // Remove the loop if it performs zero iterations.
+        if (*numIterations == 0) {
+          rewriter.replaceOp(op, op.getOutputs());
+          return success();
+        }
+        // Replace the loop induction variable by the lower bound if the loop
+        // performs a single iteration. Otherwise, copy the loop bounds.
+        if (*numIterations == 1) {
+          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+          continue;
+        }
+      }
+      newMixedLowerBounds.push_back(lb);
+      newMixedUpperBounds.push_back(ub);
+      newMixedSteps.push_back(step);
+    }
+    // Exit if none of the loop dimensions perform a single iteration.
+    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
+      return rewriter.notifyMatchFailure(
+          op, "no dimensions have 0 or 1 iterations");
+    }
+
+    // All of the loop dimensions perform a single iteration. Inline loop body.
+    if (newMixedLowerBounds.empty()) {
+      promote(rewriter, op);
+      return success();
+    }
+
+    // Replace the loop by a lower-dimensional loop.
+    ForallOp newOp;
+    newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
+                                      newMixedUpperBounds, newMixedSteps,
+                                      op.getOutputs(), std::nullopt, nullptr);
+    newOp.getBodyRegion().getBlocks().clear();
+    // The new loop needs to keep all attributes from the old one, except for
+    // "operand_segment_sizes" and static loop bound attributes which capture
+    // the outdated information of the old iteration domain.
+    SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
+                                        newOp.getStaticLowerBoundAttrName(),
+                                        newOp.getStaticUpperBoundAttrName(),
+                                        newOp.getStaticStepAttrName()};
+    for (const auto &namedAttr : op->getAttrs()) {
+      if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
+        continue;
+      rewriter.updateRootInPlace(newOp, [&]() {
+        newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+      });
+    }
+    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
+                               newOp.getRegion().begin(), mapping);
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+
 } // namespace
 
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<DimOfForallOp, ForallOpControlOperandsFolder>(context);
+  results.add<DimOfForallOp, ForallOpControlOperandsFolder,
+              ForallOpSingleOrZeroIterationDimsFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2615,41 +2753,37 @@ ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
 
 namespace {
 // Collapse loop dimensions that perform a single iteration.
-struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
+struct ParallelOpSingleOrZeroIterationDimsFolder
+    : public OpRewritePattern<ParallelOp> {
   using OpRewritePattern<ParallelOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ParallelOp op,
                                 PatternRewriter &rewriter) const override {
-    IRMapping mapping;
+    Location loc = op.getLoc();
+
     // Compute new loop bounds that omit all single-iteration loop dimensions.
-    SmallVector<Value, 2> newLowerBounds;
-    SmallVector<Value, 2> newUpperBounds;
-    SmallVector<Value, 2> newSteps;
-    newLowerBounds.reserve(op.getLowerBound().size());
-    newUpperBounds.reserve(op.getUpperBound().size());
-    newSteps.reserve(op.getStep().size());
-    for (auto [lowerBound, upperBound, step, iv] :
+    SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
+    IRMapping mapping;
+    for (auto [lb, ub, step, iv] :
          llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
                    op.getInductionVars())) {
-      // Collect the statically known loop bounds.
-      auto lowerBoundConstant =
-          dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
-      auto upperBoundConstant =
-          dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
-      auto stepConstant =
-          dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
-      // Replace the loop induction variable by the lower bound if the loop
-      // performs a single iteration. Otherwise, copy the loop bounds.
-      if (lowerBoundConstant && upperBoundConstant && stepConstant &&
-          (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
-          (upperBoundConstant.value() - lowerBoundConstant.value()) <=
-              stepConstant.value()) {
-        mapping.map(iv, lowerBound);
-      } else {
-        newLowerBounds.push_back(lowerBound);
-        newUpperBounds.push_back(upperBound);
-        newSteps.push_back(step);
+      auto numIterations = constantTripCount(lb, ub, step);
+      if (numIterations.has_value()) {
+        // Remove the loop if it performs zero iterations.
+        if (*numIterations == 0) {
+          rewriter.replaceOp(op, op.getInitVals());
+          return success();
+        }
+        // Replace the loop induction variable by the lower bound if the loop
+        // performs a single iteration. Otherwise, copy the loop bounds.
+        if (*numIterations == 1) {
+          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+          continue;
+        }
       }
+      newLowerBounds.push_back(lb);
+      newUpperBounds.push_back(ub);
+      newSteps.push_back(step);
     }
     // Exit if none of the loop dimensions perform a single iteration.
     if (newLowerBounds.size() == op.getLowerBound().size())
@@ -2694,23 +2828,6 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
   }
 };
 
-/// Removes parallel loops in which at least one lower/upper bound pair consists
-/// of the same values - such loops have an empty iteration domain.
-struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
-  using OpRewritePattern<ParallelOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ParallelOp op,
-                                PatternRewriter &rewriter) const override {
-    for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
-      if (std::get<0>(dim) == std::get<1>(dim)) {
-        rewriter.replaceOp(op, op.getInitVals());
-        return success();
-      }
-    }
-    return failure();
-  }
-};
-
 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
   using OpRewritePattern<ParallelOp>::OpRewritePattern;
 
@@ -2773,8 +2890,9 @@ struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
 
 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
-              MergeNestedParallelLoops>(context);
+  results
+      .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
index 6eca0ef..e16e288 100644 (file)
@@ -381,18 +381,12 @@ static void replaceIterArgsAndYieldResults(scf::ForOp forOp) {
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
-  auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
-  auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
-  auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
-  if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 ||
-      ubCstOp.value() < 0 || stepCstOp.value() < 0)
-    return failure();
-  int64_t tripCount =
-      mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
-  if (tripCount != 1)
+  std::optional<int64_t> tripCount = constantTripCount(
+      forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
+  if (!tripCount.has_value() || tripCount != 1)
     return failure();
   auto iv = forOp.getInductionVar();
-  iv.replaceAllUsesWith(lbCstOp);
+  iv.replaceAllUsesWith(forOp.getLowerBound());
 
   replaceIterArgsAndYieldResults(forOp);
 
index e646de9..45edd5f 100644 (file)
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/APSInt.h"
 
 namespace mlir {
@@ -228,4 +229,24 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
   return getValuesSortedByKeyImpl(keys, values, compare);
 }
 
+/// Return the number of iterations for a loop with a lower bound `lb`, upper
+/// bound `ub` and step `step`.
+std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
+                                         OpFoldResult step) {
+  if (lb == ub)
+    return 0;
+
+  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
+  if (!lbConstant)
+    return std::nullopt;
+  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
+  if (!ubConstant)
+    return std::nullopt;
+  std::optional<int64_t> stepConstant = getConstantIntValue(step);
+  if (!stepConstant)
+    return std::nullopt;
+
+  return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
+}
+
 } // namespace mlir
index a3ce8a6..f69cf19 100644 (file)
@@ -1544,3 +1544,110 @@ func.func @forall_fold_control_operands(
   return %result : tensor<?x10xf32>
 }
 // CHECK: forall (%{{.*}}, %{{.*}}) in (%{{.*}}, 10)
+
+// -----
+
+func.func @inline_forall_loop(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (%c0, %c0) to (%c1, %c1)
+        step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
+      : tensor<8x8xf32> to tensor<2x3xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
+          -> tensor<2x3xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
+        : tensor<2x3xf32> into tensor<8x8xf32>
+    }
+  }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @inline_forall_loop
+// CHECK-NOT:     scf.forall
+// CHECK:         %[[OUT:.*]] = tensor.empty
+
+// CHECK-NEXT:    %[[SLICE:.*]] = tensor.extract_slice %[[OUT]]
+// CHECK-SAME:      : tensor<8x8xf32> to tensor<2x3xf32>
+
+// CHECK-NEXT:    %[[FILL:.*]] = linalg.fill
+// CHECK-SAME:      outs(%[[SLICE]]
+
+// CHECK-NEXT:    tensor.insert_slice %[[FILL]]
+// CHECK-SAME:      : tensor<2x3xf32> into tensor<8x8xf32>
+
+// -----
+
+func.func @do_not_inline_distributed_forall_loop(
+    %in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (0, 0) to (1, 1) step (8, 8)
+      shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
+      : tensor<8x8xf32> to tensor<2x3xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
+          -> tensor<2x3xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
+        : tensor<2x3xf32> into tensor<8x8xf32>
+    }
+  }{ mapping = [#gpu.thread<y>, #gpu.thread<x>] }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @do_not_inline_distributed_forall_loop
+// CHECK: scf.forall
+
+// -----
+
+func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (0, %c0) to (1, %c16)
+        step (8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
+          -> tensor<8x8xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
+        : tensor<8x8xf32> into tensor<8x8xf32>
+    }
+  }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @collapse_one_dim_parallel
+// CHECK:         scf.forall (%[[ARG:.*]]) = (0) to (16) step (8)
+// CHECK:           linalg.fill
+// CHECK:           tensor.parallel_insert_slice
+
+// -----
+
+func.func @remove_empty_forall(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (%c0, %c16) to (%c1, %c16)
+        step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
+          -> tensor<8x8xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
+        : tensor<8x8xf32> into tensor<8x8xf32>
+    }
+  }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @remove_empty_forall
+// CHECK-NOT:   scf.forall
+// CHECK:       %[[EMPTY:.*]] = tensor.empty
+// CHECK:       return %[[EMPTY]]
+
index 2358dde..750a8d0 100644 (file)
@@ -86,7 +86,7 @@ func.func @insert_slice_rank_reducing_dynamic_shape(
 
 // CHECK-LABEL: func.func @parallel_insert_slice
 //   CHECK-NOT:   tensor.insert_slice
-//       CHECK:   tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
+//       CHECK:   tensor.parallel_insert_slice %{{.*}} into %{{.*}}[0, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
 func.func @parallel_insert_slice(%t0: tensor<1x2xf32>, %t1: tensor<f32>, %t2: tensor<1x1xf32>) -> tensor<1x2xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index