[mlir][SCF] Add `loop.promote_if_one_iteration` transform op
authorMatthias Springer <me@m-sp.org>
Tue, 4 Jul 2023 06:57:55 +0000 (08:57 +0200)
committerMatthias Springer <me@m-sp.org>
Tue, 4 Jul 2023 06:58:49 +0000 (08:58 +0200)
This transform op promotes loops with one iteration. I.e., the loop op is replaced by just the loop body.

Differential Revision: https://reviews.llvm.org/D154361

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/test/Dialect/SCF/transform-ops.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 9437f89..26cc9b1 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 
 namespace mlir {
 namespace func {
index b9d2dd1..55a378d 100644 (file)
@@ -176,6 +176,36 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
   }];
 }
 
+def LoopPromoteIfOneIterationOp : Op<Transform_Dialect,
+    "loop.promote_if_one_iteration", [
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Promote loop if it has one iteration";
+  let description = [{
+    Promotes the given target loop op if it has a single iteration. I.e., the
+    loop op is removed and only the body remains.
+
+    #### Return modes
+
+    This transform fails if the target is mapped to ops that are loops. Ops are
+    considered loops if they implement the `LoopLikeOpInterface`. Otherwise,
+    this transform always succeeds. The transform consumes the target handle and
+    modifies the payload.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::LoopLikeOpInterface target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
index 88e4562..1d6f9eb 100644 (file)
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps
   MLIRAffineDialect
   MLIRFuncDialect
   MLIRIR
+  MLIRLoopLikeInterface
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRSCFUtils
index 43c7d7e..0ca3272 100644 (file)
@@ -216,6 +216,24 @@ transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter,
 }
 
 //===----------------------------------------------------------------------===//
+// LoopPromoteIfOneIterationOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne(
+    transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  (void)target.promoteIfSingleIteration(rewriter);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::LoopPromoteIfOneIterationOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
 // LoopUnrollOp
 //===----------------------------------------------------------------------===//
 
index 28c40ef..8f8a054 100644 (file)
@@ -258,3 +258,25 @@ transform.sequence failures(propagate) {
   transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for">
   transform.loop.unroll %1 { factor = 4 } : !transform.op<"affine.for">
 }
+
+// -----
+
+// CHECK-LABEL: func @test_promote_if_one_iteration(
+//   CHECK-NOT:   scf.for
+//       CHECK:   %[[r:.*]] = "test.foo"
+//       CHECK:   return %[[r]]
+func.func @test_promote_if_one_iteration(%a: index) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.for %j = %c0 to %c1 step %c1 iter_args(%arg0 = %a) -> index {
+    %1 = "test.foo"(%a) : (index) -> (index)
+    scf.yield %1 : index
+  }
+  return %0 : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.loop.promote_if_one_iteration %0 : !transform.any_op
+}
index b09c410..17c1c2b 100644 (file)
@@ -2409,6 +2409,7 @@ cc_library(
         ":AffineUtils",
         ":FuncDialect",
         ":IR",
+        ":LoopLikeInterface",
         ":SCFDialect",
         ":SCFTransformOpsIncGen",
         ":SCFTransforms",