From efc0ba0275bd39fe0fb7548345139e9183fce20d Mon Sep 17 00:00:00 2001 From: Amy Wang Date: Tue, 17 Jan 2023 09:33:36 -0500 Subject: [PATCH] [MLIR][Transform] Introduce loop.coalesce transform op. This patch made a minor refactor of LoopCoalescing.cpp's walkLoops templated method and placed it in Affine's LoopUtils.cpp/h. This method is also renamed as coalescePerfectlyNestedLoops method. This minor change enables this method to be invoked by both the original LoopCoalescing pass as well as the newly introduced loop.coalesce transform op. The loop.coalesce transform op has the ability to coalesce affine, and scf loop nests, leveraging existing LoopCoalescing mechanism. I have created it inside the SCFTransformOps.td instead of AffineTransformOps.td as it feels to be similar in spirit as the loop.unroll op that can handle both scf and affine loops. Please let me know if you feel that this op should be moved into AffineTransformOps.td instead. The testcase added illustrates loop.coalesce transform op working for scf, affine loops (inner, outer) as well as coalesced loop can be further unrolled (achieving composibility). Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D141202 --- mlir/include/mlir/Dialect/Affine/LoopUtils.h | 49 ++++++++++++ .../Dialect/SCF/TransformOps/SCFTransformOps.td | 27 +++++++ mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 2 +- .../Dialect/Affine/Transforms/LoopCoalescing.cpp | 65 +-------------- mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 1 - .../Dialect/SCF/TransformOps/SCFTransformOps.cpp | 29 ++++++- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 5 +- mlir/test/Dialect/SCF/transform-op-coalesce.mlir | 92 ++++++++++++++++++++++ mlir/test/Dialect/SCF/transform-ops-invalid.mlir | 61 ++++++++++++++ mlir/test/Dialect/SCF/transform-ops.mlir | 25 ------ 10 files changed, 262 insertions(+), 94 deletions(-) create mode 100644 mlir/test/Dialect/SCF/transform-op-coalesce.mlir create mode 100644 mlir/test/Dialect/SCF/transform-ops-invalid.mlir diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h index 2e3a587..f598625 100644 --- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h +++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h @@ -18,6 +18,7 @@ #include "mlir/IR/Block.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/RegionUtils.h" #include namespace mlir { @@ -293,6 +294,54 @@ LogicalResult separateFullTiles(MutableArrayRef nest, SmallVectorImpl *fullTileNest = nullptr); +/// Walk either an scf.for or an affine.for to find a band to coalesce. +template +LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op) { + LogicalResult result(failure()); + SmallVector loops; + getPerfectlyNestedLoops(loops, op); + + // Look for a band of loops that can be coalesced, i.e. perfectly nested + // loops with bounds defined above some loop. + // 1. For each loop, find above which parent loop its operands are + // defined. + SmallVector operandsDefinedAbove(loops.size()); + for (unsigned i = 0, e = loops.size(); i < e; ++i) { + operandsDefinedAbove[i] = i; + for (unsigned j = 0; j < i; ++j) { + if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) { + operandsDefinedAbove[i] = j; + break; + } + } + } + + // 2. Identify bands of loops such that the operands of all of them are + // defined above the first loop in the band. Traverse the nest bottom-up + // so that modifications don't invalidate the inner loops. + for (unsigned end = loops.size(); end > 0; --end) { + unsigned start = 0; + for (; start < end - 1; ++start) { + auto maxPos = + *std::max_element(std::next(operandsDefinedAbove.begin(), start), + std::next(operandsDefinedAbove.begin(), end)); + if (maxPos > start) + continue; + assert(maxPos == start && + "expected loop bounds to be known at the start of the band"); + auto band = llvm::makeMutableArrayRef(loops.data() + start, end - start); + if (succeeded(coalesceLoops(band))) + result = success(); + break; + } + // If a band was found and transformed, keep looking at the loops above + // the outermost transformed loop. + if (start != end - 1) + end = start + 1; + } + return result; +} + } // namespace mlir #endif // MLIR_DIALECT_AFFINE_LOOPUTILS_H diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td index dd7da91..affa9ab 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -189,4 +189,31 @@ def LoopUnrollOp : Op { + let summary = "Coalesces the perfect loop nest enclosed by a given loop"; + let description = [{ + Given a perfect loop nest identified by the outermost loop, + perform loop coalescing in a bottom-up one-by-one manner. + + #### Return modes + + The return handle points to the coalesced loop if coalescing happens, or + the given input loop if coalescing does not happen. + }]; + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, $transformed)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // SCF_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index a7fada7..4d4baa9 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -121,7 +121,7 @@ getSCFMinMaxExpr(Value value, SmallVectorImpl &dims, /// Replace a perfect nest of "for" loops with a single linearized loop. Assumes /// `loops` contains a list of perfectly nested loops with bounds and steps /// independent of any loop induction variable involved in the nest. -void coalesceLoops(MutableArrayRef loops); +LogicalResult coalesceLoops(MutableArrayRef loops); /// Take the ParallelLoop and for each set of dimension indices, combine them /// into a single dimension. combinedDimensions must contain each index into diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp index c8c8240..1309270 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp @@ -32,72 +32,13 @@ namespace { struct LoopCoalescingPass : public impl::LoopCoalescingBase { - /// Walk either an scf.for or an affine.for to find a band to coalesce. - template - static void walkLoop(LoopOpTy op) { - // Ignore nested loops. - if (op->template getParentOfType()) - return; - - SmallVector loops; - getPerfectlyNestedLoops(loops, op); - LLVM_DEBUG(llvm::dbgs() - << "found a perfect nest of depth " << loops.size() << '\n'); - - // Look for a band of loops that can be coalesced, i.e. perfectly nested - // loops with bounds defined above some loop. - // 1. For each loop, find above which parent loop its operands are - // defined. - SmallVector operandsDefinedAbove(loops.size()); - for (unsigned i = 0, e = loops.size(); i < e; ++i) { - operandsDefinedAbove[i] = i; - for (unsigned j = 0; j < i; ++j) { - if (areValuesDefinedAbove(loops[i].getOperands(), - loops[j].getRegion())) { - operandsDefinedAbove[i] = j; - break; - } - } - LLVM_DEBUG(llvm::dbgs() - << " bounds of loop " << i << " are known above depth " - << operandsDefinedAbove[i] << '\n'); - } - - // 2. Identify bands of loops such that the operands of all of them are - // defined above the first loop in the band. Traverse the nest bottom-up - // so that modifications don't invalidate the inner loops. - for (unsigned end = loops.size(); end > 0; --end) { - unsigned start = 0; - for (; start < end - 1; ++start) { - auto maxPos = - *std::max_element(std::next(operandsDefinedAbove.begin(), start), - std::next(operandsDefinedAbove.begin(), end)); - if (maxPos > start) - continue; - - assert(maxPos == start && - "expected loop bounds to be known at the start of the band"); - LLVM_DEBUG(llvm::dbgs() << " found coalesceable band from " << start - << " to " << end << '\n'); - - auto band = llvm::MutableArrayRef(loops.data() + start, end - start); - (void)coalesceLoops(band); - break; - } - // If a band was found and transformed, keep looking at the loops above - // the outermost transformed loop. - if (start != end - 1) - end = start + 1; - } - } - void runOnOperation() override { func::FuncOp func = getOperation(); - func.walk([&](Operation *op) { + func.walk([](Operation *op) { if (auto scfForOp = dyn_cast(op)) - walkLoop(scfForOp); + (void)coalescePerfectlyNestedLoops(scfForOp); else if (auto affineForOp = dyn_cast(op)) - walkLoop(affineForOp); + (void)coalescePerfectlyNestedLoops(affineForOp); }); } }; diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 5860086..ec85e56 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -25,7 +25,6 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 1ee5a02..5477af7 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -219,9 +219,32 @@ transform::LoopUnrollOp::applyToOne(Operation *op, result = loopUnrollByFactor(affineFor, getFactor()); if (failed(result)) { - Diagnostic diag(op->getLoc(), DiagnosticSeverity::Note); - diag << "Op failed to unroll"; - return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + DiagnosedSilenceableFailure diag = emitSilenceableError() + << "failed to unroll"; + return diag; + } + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LoopCoalesceOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::LoopCoalesceOp::applyToOne(Operation *op, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + LogicalResult result(failure()); + if (scf::ForOp scfForOp = dyn_cast(op)) + result = coalescePerfectlyNestedLoops(scfForOp); + else if (AffineForOp affineForOp = dyn_cast(op)) + result = coalescePerfectlyNestedLoops(affineForOp); + + results.push_back(op); + if (failed(result)) { + DiagnosedSilenceableFailure diag = emitSilenceableError() + << "failed to coalesce"; + return diag; } return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index b4c60b6..6eca0ef 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -656,9 +656,9 @@ static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) { loop.setStep(loopPieces.step); } -void mlir::coalesceLoops(MutableArrayRef loops) { +LogicalResult mlir::coalesceLoops(MutableArrayRef loops) { if (loops.size() < 2) - return; + return failure(); scf::ForOp innermost = loops.back(); scf::ForOp outermost = loops.front(); @@ -710,6 +710,7 @@ void mlir::coalesceLoops(MutableArrayRef loops) { Block::iterator(second.getOperation()), innermost.getBody()->getOperations()); second.erase(); + return success(); } void mlir::collapseParallelLoops( diff --git a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir new file mode 100644 index 0000000..4c84f62 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir @@ -0,0 +1,92 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +func.func @coalesce_inner() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: scf.for %[[IV0:.+]] + // CHECK: scf.for %[[IV1:.+]] + // CHECK: scf.for %[[IV2:.+]] + // CHECK-NOT: scf.for %[[IV3:.+]] + scf.for %i = %c0 to %c10 step %c1 { + scf.for %j = %c0 to %c10 step %c1 { + scf.for %k = %i to %j step %c1 { + // Inner loop must have been removed. + scf.for %l = %i to %j step %c1 { + arith.addi %i, %j : index + } + } {coalesce} + } + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for"> + %2 = transform.loop.coalesce %1: (!transform.op<"scf.for">) -> (!transform.op<"scf.for">) +} + +// ----- + +func.func @coalesce_outer(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} { + // CHECK: affine.for %[[IV1:.+]] = 0 to %[[UB:.+]] { + // CHECK-NOT: affine.for %[[IV2:.+]] + affine.for %arg4 = 0 to 64 { + affine.for %arg5 = 0 to 64 { + // CHECK: %[[IDX0:.+]] = affine.apply #[[MAP0:.+]](%[[IV1]])[%{{.+}}] + // CHECK: %[[IDX1:.+]] = affine.apply #[[MAP1:.+]](%[[IV1]])[%{{.+}}] + // CHECK-NEXT: %{{.+}} = affine.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1> + %0 = affine.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1> + %1 = affine.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1> + %2 = arith.addf %0, %1 : f32 + affine.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1> + } + } {coalesce} + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["affine.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"affine.for"> + %2 = transform.loop.coalesce %1 : (!transform.op<"affine.for">) -> (!transform.op<"affine.for">) +} + +// ----- + +func.func @coalesce_and_unroll(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} { + // CHECK: scf.for %[[IV1:.+]] = + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + scf.for %arg4 = %c0 to %c64 step %c1 { + // CHECK-NOT: scf.for + scf.for %arg5 = %c0 to %c64 step %c1 { + // CHECK: %[[IDX0:.+]] = arith.remsi %[[IV1]] + // CHECK: %[[IDX1:.+]] = arith.divsi %[[IV1]] + // CHECK-NEXT: %{{.+}} = memref.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1> + %0 = memref.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1> + %1 = memref.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1> + %2 = arith.addf %0, %1 : f32 + // CHECK: memref.store + // CHECK: memref.store + // CHECK: memref.store + // Residual loop must have a single store. + // CHECK: memref.store + memref.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1> + } + } {coalesce} + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for"> + %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">) + transform.loop.unroll %2 {factor = 3} : !transform.op<"scf.for"> +} diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir new file mode 100644 index 0000000..57812de --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file --verify-diagnostics + +#map0 = affine_map<(d0) -> (d0 * 110)> +#map1 = affine_map<(d0) -> (696, d0 * 110 + 110)> +func.func @test_loops_do_not_get_coalesced() { + affine.for %i = 0 to 7 { + affine.for %j = #map0(%i) to min #map1(%i) { + } + } {coalesce} + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["affine.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"affine.for"> + // expected-error @below {{failed to coalesce}} + %2 = transform.loop.coalesce %1: (!transform.op<"affine.for">) -> (!transform.op<"affine.for">) +} + +// ----- + +func.func @test_loops_do_not_get_unrolled() { + affine.for %i = 0 to 7 { + arith.addi %i, %i : index + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["arith.addi"]} in %arg1 + %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for"> + // expected-error @below {{failed to unroll}} + transform.loop.unroll %1 { factor = 8 } : !transform.op<"affine.for"> +} + +// ----- + +func.func private @cond() -> i1 +func.func private @body() + +func.func @loop_outline_op_multi_region() { + // expected-note @below {{target op}} + scf.while : () -> () { + %0 = func.call @cond() : () -> i1 + scf.condition(%0) + } do { + ^bb0: + func.call @body() : () -> () + scf.yield + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["scf.while"]} in %arg1 + // expected-error @below {{failed to outline}} + transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir index d6ff2f2..0e4b384 100644 --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -84,31 +84,6 @@ transform.sequence failures(propagate) { // ----- -func.func private @cond() -> i1 -func.func private @body() - -func.func @loop_outline_op_multi_region() { - // expected-note @below {{target op}} - scf.while : () -> () { - %0 = func.call @cond() : () -> i1 - scf.condition(%0) - } do { - ^bb0: - func.call @body() : () -> () - scf.yield - } - return -} - -transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["scf.while"]} in %arg1 - // expected-error @below {{failed to outline}} - transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation -} - -// ----- - // CHECK-LABEL: @loop_peel_op func.func @loop_peel_op() { // CHECK: %[[C0:.+]] = arith.constant 0 -- 2.7.4