[MLIR][Transform] Introduce loop.coalesce transform op.
authorAmy Wang <kai.ting.wang@huawei.com>
Tue, 17 Jan 2023 14:33:36 +0000 (09:33 -0500)
committerPrabhdeep Singh Soni <prabhdeep.singh.soni3@huawei.com>
Tue, 17 Jan 2023 14:38:47 +0000 (09:38 -0500)
This patch made a minor refactor of LoopCoalescing.cpp's walkLoops
templated method and placed it in Affine's LoopUtils.cpp/h.
This method is also renamed as coalescePerfectlyNestedLoops method. This
minor change enables this method to be invoked
by both the original LoopCoalescing pass as well as the newly introduced
loop.coalesce transform op.

The loop.coalesce transform op has the ability to coalesce affine, and
scf loop nests, leveraging existing LoopCoalescing
mechanism. I have created it inside the SCFTransformOps.td instead of
AffineTransformOps.td as it feels to be similar
in spirit as the loop.unroll op that can handle both scf and affine
loops. Please let me know if you feel that this op
should be moved into AffineTransformOps.td instead.

The testcase added illustrates loop.coalesce transform op working for
scf, affine loops (inner, outer) as well as
coalesced loop can be further unrolled (achieving composibility).

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D141202

mlir/include/mlir/Dialect/Affine/LoopUtils.h
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Dialect/SCF/transform-op-coalesce.mlir [new file with mode: 0644]
mlir/test/Dialect/SCF/transform-ops-invalid.mlir [new file with mode: 0644]
mlir/test/Dialect/SCF/transform-ops.mlir

index 2e3a587..f598625 100644 (file)
@@ -18,6 +18,7 @@
 #include "mlir/IR/Block.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 
 namespace mlir {
@@ -293,6 +294,54 @@ LogicalResult
 separateFullTiles(MutableArrayRef<AffineForOp> nest,
                   SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
 
+/// Walk either an scf.for or an affine.for to find a band to coalesce.
+template <typename LoopOpTy>
+LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op) {
+  LogicalResult result(failure());
+  SmallVector<LoopOpTy> loops;
+  getPerfectlyNestedLoops(loops, op);
+
+  // Look for a band of loops that can be coalesced, i.e. perfectly nested
+  // loops with bounds defined above some loop.
+  // 1. For each loop, find above which parent loop its operands are
+  // defined.
+  SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
+  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+    operandsDefinedAbove[i] = i;
+    for (unsigned j = 0; j < i; ++j) {
+      if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
+        operandsDefinedAbove[i] = j;
+        break;
+      }
+    }
+  }
+
+  // 2. Identify bands of loops such that the operands of all of them are
+  // defined above the first loop in the band.  Traverse the nest bottom-up
+  // so that modifications don't invalidate the inner loops.
+  for (unsigned end = loops.size(); end > 0; --end) {
+    unsigned start = 0;
+    for (; start < end - 1; ++start) {
+      auto maxPos =
+          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+                            std::next(operandsDefinedAbove.begin(), end));
+      if (maxPos > start)
+        continue;
+      assert(maxPos == start &&
+             "expected loop bounds to be known at the start of the band");
+      auto band = llvm::makeMutableArrayRef(loops.data() + start, end - start);
+      if (succeeded(coalesceLoops(band)))
+        result = success();
+      break;
+    }
+    // If a band was found and transformed, keep looking at the loops above
+    // the outermost transformed loop.
+    if (start != end - 1)
+      end = start + 1;
+  }
+  return result;
+}
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_LOOPUTILS_H
index dd7da91..affa9ab 100644 (file)
@@ -189,4 +189,31 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
   }];
 }
 
+def LoopCoalesceOp : Op<Transform_Dialect, "loop.coalesce", [
+  FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+  TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Coalesces the perfect loop nest enclosed by a given loop";
+  let description = [{
+    Given a perfect loop nest identified by the outermost loop,
+    perform loop coalescing in a bottom-up one-by-one manner.
+
+    #### Return modes
+
+    The return handle points to the coalesced loop if coalescing happens, or
+    the given input loop if coalescing does not happen.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+      "$target attr-dict `:` functional-type($target, $transformed)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // SCF_TRANSFORM_OPS
index a7fada7..4d4baa9 100644 (file)
@@ -121,7 +121,7 @@ getSCFMinMaxExpr(Value value, SmallVectorImpl<Value> &dims,
 /// Replace a perfect nest of "for" loops with a single linearized loop. Assumes
 /// `loops` contains a list of perfectly nested loops with bounds and steps
 /// independent of any loop induction variable involved in the nest.
-void coalesceLoops(MutableArrayRef<scf::ForOp> loops);
+LogicalResult coalesceLoops(MutableArrayRef<scf::ForOp> loops);
 
 /// Take the ParallelLoop and for each set of dimension indices, combine them
 /// into a single dimension. combinedDimensions must contain each index into
index c8c8240..1309270 100644 (file)
@@ -32,72 +32,13 @@ namespace {
 struct LoopCoalescingPass
     : public impl::LoopCoalescingBase<LoopCoalescingPass> {
 
-  /// Walk either an scf.for or an affine.for to find a band to coalesce.
-  template <typename LoopOpTy>
-  static void walkLoop(LoopOpTy op) {
-    // Ignore nested loops.
-    if (op->template getParentOfType<LoopOpTy>())
-      return;
-
-    SmallVector<LoopOpTy, 4> loops;
-    getPerfectlyNestedLoops(loops, op);
-    LLVM_DEBUG(llvm::dbgs()
-               << "found a perfect nest of depth " << loops.size() << '\n');
-
-    // Look for a band of loops that can be coalesced, i.e. perfectly nested
-    // loops with bounds defined above some loop.
-    // 1. For each loop, find above which parent loop its operands are
-    // defined.
-    SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
-    for (unsigned i = 0, e = loops.size(); i < e; ++i) {
-      operandsDefinedAbove[i] = i;
-      for (unsigned j = 0; j < i; ++j) {
-        if (areValuesDefinedAbove(loops[i].getOperands(),
-                                  loops[j].getRegion())) {
-          operandsDefinedAbove[i] = j;
-          break;
-        }
-      }
-      LLVM_DEBUG(llvm::dbgs()
-                 << "  bounds of loop " << i << " are known above depth "
-                 << operandsDefinedAbove[i] << '\n');
-    }
-
-    // 2. Identify bands of loops such that the operands of all of them are
-    // defined above the first loop in the band.  Traverse the nest bottom-up
-    // so that modifications don't invalidate the inner loops.
-    for (unsigned end = loops.size(); end > 0; --end) {
-      unsigned start = 0;
-      for (; start < end - 1; ++start) {
-        auto maxPos =
-            *std::max_element(std::next(operandsDefinedAbove.begin(), start),
-                              std::next(operandsDefinedAbove.begin(), end));
-        if (maxPos > start)
-          continue;
-
-        assert(maxPos == start &&
-               "expected loop bounds to be known at the start of the band");
-        LLVM_DEBUG(llvm::dbgs() << "  found coalesceable band from " << start
-                                << " to " << end << '\n');
-
-        auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
-        (void)coalesceLoops(band);
-        break;
-      }
-      // If a band was found and transformed, keep looking at the loops above
-      // the outermost transformed loop.
-      if (start != end - 1)
-        end = start + 1;
-    }
-  }
-
   void runOnOperation() override {
     func::FuncOp func = getOperation();
-    func.walk([&](Operation *op) {
+    func.walk<WalkOrder::PreOrder>([](Operation *op) {
       if (auto scfForOp = dyn_cast<scf::ForOp>(op))
-        walkLoop(scfForOp);
+        (void)coalescePerfectlyNestedLoops(scfForOp);
       else if (auto affineForOp = dyn_cast<AffineForOp>(op))
-        walkLoop(affineForOp);
+        (void)coalescePerfectlyNestedLoops(affineForOp);
     });
   }
 };
index 5860086..ec85e56 100644 (file)
@@ -25,7 +25,6 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
index 1ee5a02..5477af7 100644 (file)
@@ -219,9 +219,32 @@ transform::LoopUnrollOp::applyToOne(Operation *op,
     result = loopUnrollByFactor(affineFor, getFactor());
 
   if (failed(result)) {
-    Diagnostic diag(op->getLoc(), DiagnosticSeverity::Note);
-    diag << "Op failed to unroll";
-    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to unroll";
+    return diag;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopCoalesceOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::LoopCoalesceOp::applyToOne(Operation *op,
+                                      transform::ApplyToEachResultList &results,
+                                      transform::TransformState &state) {
+  LogicalResult result(failure());
+  if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
+    result = coalescePerfectlyNestedLoops(scfForOp);
+  else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
+    result = coalescePerfectlyNestedLoops(affineForOp);
+
+  results.push_back(op);
+  if (failed(result)) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to coalesce";
+    return diag;
   }
   return DiagnosedSilenceableFailure::success();
 }
index b4c60b6..6eca0ef 100644 (file)
@@ -656,9 +656,9 @@ static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
   loop.setStep(loopPieces.step);
 }
 
-void mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
+LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
   if (loops.size() < 2)
-    return;
+    return failure();
 
   scf::ForOp innermost = loops.back();
   scf::ForOp outermost = loops.front();
@@ -710,6 +710,7 @@ void mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
       Block::iterator(second.getOperation()),
       innermost.getBody()->getOperations());
   second.erase();
+  return success();
 }
 
 void mlir::collapseParallelLoops(
diff --git a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
new file mode 100644 (file)
index 0000000..4c84f62
--- /dev/null
@@ -0,0 +1,92 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func @coalesce_inner() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  // CHECK: scf.for %[[IV0:.+]]
+  // CHECK:   scf.for %[[IV1:.+]]
+  // CHECK:     scf.for %[[IV2:.+]]
+  // CHECK-NOT:   scf.for %[[IV3:.+]]
+  scf.for %i = %c0 to %c10 step %c1 {
+    scf.for %j = %c0 to %c10 step %c1 {
+      scf.for %k = %i to %j step %c1 {
+        // Inner loop must have been removed.
+        scf.for %l = %i to %j step %c1 {
+          arith.addi %i, %j : index
+        }
+      } {coalesce}
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1
+  %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for">
+  %2 = transform.loop.coalesce %1: (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+}
+
+// -----
+
+func.func @coalesce_outer(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} {
+  // CHECK: affine.for %[[IV1:.+]] = 0 to %[[UB:.+]] {
+  // CHECK-NOT: affine.for %[[IV2:.+]]
+  affine.for %arg4 = 0 to 64 {
+    affine.for %arg5 = 0 to 64 {
+      // CHECK: %[[IDX0:.+]] = affine.apply #[[MAP0:.+]](%[[IV1]])[%{{.+}}]
+      // CHECK: %[[IDX1:.+]] = affine.apply #[[MAP1:.+]](%[[IV1]])[%{{.+}}]
+      // CHECK-NEXT: %{{.+}} = affine.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1>
+      %0 = affine.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1>
+      %1 = affine.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1>
+      %2 = arith.addf %0, %1 : f32
+      affine.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1>
+    }
+  } {coalesce}
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["affine.for"]} attributes {coalesce} in %arg1
+  %1 = transform.cast %0 : !pdl.operation to !transform.op<"affine.for">
+  %2 = transform.loop.coalesce %1 : (!transform.op<"affine.for">) -> (!transform.op<"affine.for">)
+}
+
+// -----
+
+func.func @coalesce_and_unroll(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} {
+  // CHECK: scf.for %[[IV1:.+]] =
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c64 = arith.constant 64 : index
+
+  scf.for %arg4 = %c0 to %c64 step %c1 {
+    // CHECK-NOT: scf.for
+    scf.for %arg5 = %c0 to %c64 step %c1 {
+      // CHECK: %[[IDX0:.+]] = arith.remsi %[[IV1]]
+      // CHECK: %[[IDX1:.+]] = arith.divsi %[[IV1]]
+      // CHECK-NEXT: %{{.+}} = memref.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1>
+      %0 = memref.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1>
+      %1 = memref.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1>
+      %2 = arith.addf %0, %1 : f32
+      // CHECK: memref.store
+      // CHECK: memref.store
+      // CHECK: memref.store
+      // Residual loop must have a single store.
+      // CHECK: memref.store
+      memref.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1>
+    }
+  } {coalesce}
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1
+  %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for">
+  %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+  transform.loop.unroll %2 {factor = 3} : !transform.op<"scf.for">
+}
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
new file mode 100644 (file)
index 0000000..57812de
--- /dev/null
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file --verify-diagnostics
+
+#map0 = affine_map<(d0) -> (d0 * 110)>
+#map1 = affine_map<(d0) -> (696, d0 * 110 + 110)>
+func.func @test_loops_do_not_get_coalesced() {
+  affine.for %i = 0 to 7 {
+    affine.for %j = #map0(%i) to min #map1(%i) {
+    }
+  } {coalesce}
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["affine.for"]} attributes {coalesce} in %arg1
+  %1 = transform.cast %0 : !pdl.operation to !transform.op<"affine.for">
+  // expected-error @below {{failed to coalesce}}
+  %2 = transform.loop.coalesce %1: (!transform.op<"affine.for">) -> (!transform.op<"affine.for">)
+}
+
+// -----
+
+func.func @test_loops_do_not_get_unrolled() {
+  affine.for %i = 0 to 7 {
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
+  %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  // expected-error @below {{failed to unroll}}
+  transform.loop.unroll %1 { factor = 8 } : !transform.op<"affine.for">
+}
+
+// -----
+
+func.func private @cond() -> i1
+func.func private @body()
+
+func.func @loop_outline_op_multi_region() {
+  // expected-note @below {{target op}}
+  scf.while : () -> () {
+    %0 = func.call @cond() : () -> i1
+    scf.condition(%0)
+  } do {
+  ^bb0:
+    func.call @body() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["scf.while"]} in %arg1
+  // expected-error @below {{failed to outline}}
+  transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
+}
index d6ff2f2..0e4b384 100644 (file)
@@ -84,31 +84,6 @@ transform.sequence failures(propagate) {
 
 // -----
 
-func.func private @cond() -> i1
-func.func private @body()
-
-func.func @loop_outline_op_multi_region() {
-  // expected-note @below {{target op}}
-  scf.while : () -> () {
-    %0 = func.call @cond() : () -> i1
-    scf.condition(%0)
-  } do {
-  ^bb0:
-    func.call @body() : () -> ()
-    scf.yield
-  }
-  return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["scf.while"]} in %arg1
-  // expected-error @below {{failed to outline}}
-  transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
-}
-
-// -----
-
 // CHECK-LABEL: @loop_peel_op
 func.func @loop_peel_op() {
   // CHECK: %[[C0:.+]] = arith.constant 0