regions.push_back(RegionSuccessor(getResults()));
}
+/// Promotes the loop body of a scf::ForallOp to its containing block if it
+/// can be determined that every loop dimension performs a single iteration;
+/// returns failure otherwise.
+LogicalResult mlir::scf::promoteIfSingleIteration(PatternRewriter &rewriter,
+ scf::ForallOp forallOp) {
+ for (auto [lb, ub, step] :
+ llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep())) {
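+ // constantTripCount folds ceilDiv(ub - lb, step) when lb, ub, and step are
+ // all known constants and returns std::nullopt otherwise.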
+ auto tripCount = constantTripCount(lb, ub, step);
+ if (!tripCount.has_value() || *tripCount != 1)
+ return failure();
+ }
+
+ promote(rewriter, forallOp);
+ return success();
+}
+
+/// Promotes the loop body of a scf::ForallOp to its containing block.
+void mlir::scf::promote(PatternRewriter &rewriter, scf::ForallOp forallOp) {
+ IRMapping mapping;
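+ // Map each induction variable to its lower bound and each shared_outs block
+ // argument to the corresponding output operand before cloning the body.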
+ mapping.map(forallOp.getInductionVars(), forallOp.getLowerBound(rewriter));
+ mapping.map(forallOp.getOutputBlockArguments(), forallOp.getOutputs());
+ for (auto &bodyOp : forallOp.getBody()->without_terminator())
+ rewriter.clone(bodyOp, mapping);
+
+ SmallVector<Value> results;
+ results.reserve(forallOp.getResults().size());
+ scf::InParallelOp terminator = forallOp.getTerminator();
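+ // Turn each tensor.parallel_insert_slice of the terminator into a
+ // tensor.insert_slice on the cloned values; its result becomes the
+ // replacement for the corresponding loop result.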
+ for (auto &yieldingOp : terminator.getYieldingOps()) {
+ auto parallelInsertSliceOp =
+ cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+
+ Value dst = parallelInsertSliceOp.getDest();
+ Value src = parallelInsertSliceOp.getSource();
+
+ auto getMappedValues = [&](ValueRange values) {
+ return llvm::to_vector(llvm::map_range(
+ values, [&](Value value) { return mapping.lookupOrDefault(value); }));
+ };
+
+ Value srcVal = mapping.lookupOrDefault(src);
+ if (srcVal.getType().isa<TensorType>()) {
+ results.push_back(rewriter.create<tensor::InsertSliceOp>(
+ forallOp.getLoc(), dst.getType(), srcVal,
+ mapping.lookupOrDefault(dst),
+ getMappedValues(parallelInsertSliceOp.getOffsets()),
+ getMappedValues(parallelInsertSliceOp.getSizes()),
+ getMappedValues(parallelInsertSliceOp.getStrides()),
+ parallelInsertSliceOp.getStaticOffsets(),
+ parallelInsertSliceOp.getStaticSizes(),
+ parallelInsertSliceOp.getStaticStrides()));
+ }
+ }
+ rewriter.replaceOp(forallOp, results);
+}
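+
+// Illustrative sketch only (PromoteSingleIterationForall is a hypothetical
+// pattern, not an existing class): a rewrite pattern can simply forward to the
+// helper above, e.g.
+//
+//   struct PromoteSingleIterationForall
+//       : public OpRewritePattern<scf::ForallOp> {
+//     using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
+//     LogicalResult matchAndRewrite(scf::ForallOp op,
+//                                   PatternRewriter &rewriter) const override {
+//       return promoteIfSingleIteration(rewriter, op);
+//     }
+//   };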
+
LoopNest mlir::scf::buildLoopNest(
OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
ValueRange steps, ValueRange iterArgs,
dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
op.getDynamicStepMutable().assign(dynamicStep);
op.setStaticStep(staticStep);
+
+ op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
+ rewriter.getDenseI32ArrayAttr(
+ {static_cast<int32_t>(dynamicLowerBound.size()),
+ static_cast<int32_t>(dynamicUpperBound.size()),
+ static_cast<int32_t>(dynamicStep.size()),
+ static_cast<int32_t>(op.getNumResults())}));
});
return success();
}
};
+struct ForallOpSingleOrZeroIterationDimsFolder
+ : public OpRewritePattern<ForallOp> {
+ using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ForallOp op,
+ PatternRewriter &rewriter) const override {
+ // Do not fold dimensions if they are mapped to processing units.
+ if (op.getMapping().has_value())
+ return failure();
+ Location loc = op.getLoc();
+
+ // Compute new loop bounds that omit all single-iteration loop dimensions.
+ SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
+ newMixedSteps;
+ IRMapping mapping;
+ for (auto [lb, ub, step, iv] :
+ llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
+ op.getMixedStep(), op.getInductionVars())) {
+ auto numIterations = constantTripCount(lb, ub, step);
+ if (numIterations.has_value()) {
+ // Remove the loop if it performs zero iterations.
+ if (*numIterations == 0) {
+ rewriter.replaceOp(op, op.getOutputs());
+ return success();
+ }
+ // Replace the loop induction variable by the lower bound if the loop
+ // performs a single iteration. Otherwise, copy the loop bounds.
+ if (*numIterations == 1) {
+ mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+ continue;
+ }
+ }
+ newMixedLowerBounds.push_back(lb);
+ newMixedUpperBounds.push_back(ub);
+ newMixedSteps.push_back(step);
+ }
+ // Exit if none of the loop dimensions perform zero or a single iteration.
+ if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
+ return rewriter.notifyMatchFailure(
+ op, "no dimensions have 0 or 1 iterations");
+ }
+
+ // All of the loop dimensions perform a single iteration. Inline loop body.
+ if (newMixedLowerBounds.empty()) {
+ promote(rewriter, op);
+ return success();
+ }
+
+ // Replace the loop by a lower-dimensional loop.
+ ForallOp newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
+ newMixedUpperBounds, newMixedSteps,
+ op.getOutputs(), std::nullopt, nullptr);
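+ // The builder materializes a default body block; clear it so the original
+ // region can be cloned in below, with the induction variables of the dropped
+ // dimensions remapped to their constant lower bounds via `mapping`.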
+ newOp.getBodyRegion().getBlocks().clear();
+ // The new loop needs to keep all attributes from the old one, except for
+ // "operand_segment_sizes" and static loop bound attributes which capture
+ // the outdated information of the old iteration domain.
+ SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
+ newOp.getStaticLowerBoundAttrName(),
+ newOp.getStaticUpperBoundAttrName(),
+ newOp.getStaticStepAttrName()};
+ for (const auto &namedAttr : op->getAttrs()) {
+ if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
+ continue;
+ rewriter.updateRootInPlace(newOp, [&]() {
+ newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+ });
+ }
+ rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
+ newOp.getRegion().begin(), mapping);
+ rewriter.replaceOp(op, newOp.getResults());
+ return success();
+ }
+};
+
} // namespace
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfForallOp, ForallOpControlOperandsFolder>(context);
+ results.add<DimOfForallOp, ForallOpControlOperandsFolder,
+ ForallOpSingleOrZeroIterationDimsFolder>(context);
}
//===----------------------------------------------------------------------===//
namespace {
// Collapse loop dimensions that perform a single iteration.
-struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
+struct ParallelOpSingleOrZeroIterationDimsFolder
+ : public OpRewritePattern<ParallelOp> {
using OpRewritePattern<ParallelOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ParallelOp op,
PatternRewriter &rewriter) const override {
- IRMapping mapping;
+ Location loc = op.getLoc();
+
// Compute new loop bounds that omit all single-iteration loop dimensions.
- SmallVector<Value, 2> newLowerBounds;
- SmallVector<Value, 2> newUpperBounds;
- SmallVector<Value, 2> newSteps;
- newLowerBounds.reserve(op.getLowerBound().size());
- newUpperBounds.reserve(op.getUpperBound().size());
- newSteps.reserve(op.getStep().size());
- for (auto [lowerBound, upperBound, step, iv] :
+ SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
+ IRMapping mapping;
+ for (auto [lb, ub, step, iv] :
llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
op.getInductionVars())) {
- // Collect the statically known loop bounds.
- auto lowerBoundConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
- auto upperBoundConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
- auto stepConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
- // Replace the loop induction variable by the lower bound if the loop
- // performs a single iteration. Otherwise, copy the loop bounds.
- if (lowerBoundConstant && upperBoundConstant && stepConstant &&
- (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
- (upperBoundConstant.value() - lowerBoundConstant.value()) <=
- stepConstant.value()) {
- mapping.map(iv, lowerBound);
- } else {
- newLowerBounds.push_back(lowerBound);
- newUpperBounds.push_back(upperBound);
- newSteps.push_back(step);
+ auto numIterations = constantTripCount(lb, ub, step);
+ if (numIterations.has_value()) {
+ // Remove the loop if it performs zero iterations.
+ if (*numIterations == 0) {
+ rewriter.replaceOp(op, op.getInitVals());
+ return success();
+ }
+ // Replace the loop induction variable by the lower bound if the loop
+ // performs a single iteration. Otherwise, copy the loop bounds.
+ if (*numIterations == 1) {
+ mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+ continue;
+ }
}
+ newLowerBounds.push_back(lb);
+ newUpperBounds.push_back(ub);
+ newSteps.push_back(step);
}
// Exit if none of the loop dimensions perform a single iteration.
if (newLowerBounds.size() == op.getLowerBound().size())
}
};
-/// Removes parallel loops in which at least one lower/upper bound pair consists
-/// of the same values - such loops have an empty iteration domain.
-struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
- using OpRewritePattern<ParallelOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ParallelOp op,
- PatternRewriter &rewriter) const override {
- for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
- if (std::get<0>(dim) == std::get<1>(dim)) {
- rewriter.replaceOp(op, op.getInitVals());
- return success();
- }
- }
- return failure();
- }
-};
-
struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
using OpRewritePattern<ParallelOp>::OpRewritePattern;
void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
- MergeNestedParallelLoops>(context);
+ results
+ .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
+ context);
}
//===----------------------------------------------------------------------===//
return %result : tensor<?x10xf32>
}
// CHECK: forall (%{{.*}}, %{{.*}}) in (%{{.*}}, 10)
+
+// -----
+
+func.func @inline_forall_loop(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<8x8xf32>
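+ // Each dimension runs from 0 to 1 with step 8, i.e. a single iteration, so
+ // the canonicalizer inlines the loop body into the enclosing block.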
+ %1 = scf.forall (%i, %j) = (%c0, %c0) to (%c1, %c1)
+ step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+ %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
+ : tensor<8x8xf32> to tensor<2x3xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
+ -> tensor<2x3xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
+ : tensor<2x3xf32> into tensor<8x8xf32>
+ }
+ }
+ return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @inline_forall_loop
+// CHECK-NOT: scf.forall
+// CHECK: %[[OUT:.*]] = tensor.empty
+// CHECK-NEXT: %[[SLICE:.*]] = tensor.extract_slice %[[OUT]]
+// CHECK-SAME: : tensor<8x8xf32> to tensor<2x3xf32>
+// CHECK-NEXT: %[[FILL:.*]] = linalg.fill
+// CHECK-SAME: outs(%[[SLICE]]
+// CHECK-NEXT: tensor.insert_slice %[[FILL]]
+// CHECK-SAME: : tensor<2x3xf32> into tensor<8x8xf32>
+
+// -----
+
+func.func @do_not_inline_distributed_forall_loop(
+ %in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<8x8xf32>
+ %1 = scf.forall (%i, %j) = (0, 0) to (1, 1) step (8, 8)
+ shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+ %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
+ : tensor<8x8xf32> to tensor<2x3xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
+ -> tensor<2x3xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
+ : tensor<2x3xf32> into tensor<8x8xf32>
+ }
+ } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+ return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @do_not_inline_distributed_forall_loop
+// CHECK: scf.forall
+
+// -----
+
+func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<8x8xf32>
+ %1 = scf.forall (%i, %j) = (0, %c0) to (1, %c16)
+ step (8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+ %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
+ -> tensor<8x8xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
+ : tensor<8x8xf32> into tensor<8x8xf32>
+ }
+ }
+ return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @collapse_one_dim_parallel
+// CHECK: scf.forall (%[[ARG:.*]]) = (0) to (16) step (8)
+// CHECK: linalg.fill
+// CHECK: tensor.parallel_insert_slice
+
+// -----
+
+func.func @remove_empty_forall(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<8x8xf32>
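+ // The second dimension has equal lower and upper bounds (16), i.e. zero
+ // iterations, so the loop folds away to its shared_outs init value.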
+ %1 = scf.forall (%i, %j) = (%c0, %c16) to (%c1, %c16)
+ step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+ %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
+ -> tensor<8x8xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
+ : tensor<8x8xf32> into tensor<8x8xf32>
+ }
+ }
+ return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @remove_empty_forall
+// CHECK-NOT: scf.forall
+// CHECK: %[[EMPTY:.*]] = tensor.empty
+// CHECK: return %[[EMPTY]]
+