From 8b8e62d3f6807e8dadd9aef93ea64663f7ec805c Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 4 Jul 2023 08:57:55 +0200 Subject: [PATCH] [mlir][SCF] Add `loop.promote_if_one_iteration` transform op This transform op promotes loops with one iteration. I.e., the loop op is replaced by just the loop body. Differential Revision: https://reviews.llvm.org/D154361 --- .../Dialect/SCF/TransformOps/SCFTransformOps.h | 1 + .../Dialect/SCF/TransformOps/SCFTransformOps.td | 30 ++++++++++++++++++++++ mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt | 1 + .../Dialect/SCF/TransformOps/SCFTransformOps.cpp | 18 +++++++++++++ mlir/test/Dialect/SCF/transform-ops.mlir | 22 ++++++++++++++++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 + 6 files changed, 73 insertions(+) diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h index 9437f89..26cc9b1 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h @@ -13,6 +13,7 @@ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/LoopLikeInterface.h" namespace mlir { namespace func { diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td index b9d2dd1..55a378d 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -176,6 +176,36 @@ def LoopPipelineOp : Op, + TransformOpInterface, TransformEachOpTrait]> { + let summary = "Promote loop if it has one iteration"; + let description = [{ + Promotes the given target loop op if it has a single iteration. I.e., the + loop op is removed and only the body remains. + + #### Return modes + + This transform fails if the target is mapped to ops that are loops. Ops are + considered loops if they implement the `LoopLikeOpInterface`. Otherwise, + this transform always succeeds. The transform consumes the target handle and + modifies the payload. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + let assemblyFormat = "$target attr-dict `:` type($target)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::LoopLikeOpInterface target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + def LoopUnrollOp : Op { diff --git a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt index 88e4562..1d6f9eb 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps MLIRAffineDialect MLIRFuncDialect MLIRIR + MLIRLoopLikeInterface MLIRSCFDialect MLIRSCFTransforms MLIRSCFUtils diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 43c7d7e..0ca3272 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -216,6 +216,24 @@ transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, } //===----------------------------------------------------------------------===// +// LoopPromoteIfOneIterationOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne( + transform::TransformRewriter &rewriter, LoopLikeOpInterface target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + (void)target.promoteIfSingleIteration(rewriter); + return DiagnosedSilenceableFailure::success(); +} + +void transform::LoopPromoteIfOneIterationOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getTarget(), effects); + modifiesPayload(effects); +} + +//===----------------------------------------------------------------------===// // LoopUnrollOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir index 28c40ef..8f8a054 100644 --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -258,3 +258,25 @@ transform.sequence failures(propagate) { transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for"> transform.loop.unroll %1 { factor = 4 } : !transform.op<"affine.for"> } + +// ----- + +// CHECK-LABEL: func @test_promote_if_one_iteration( +// CHECK-NOT: scf.for +// CHECK: %[[r:.*]] = "test.foo" +// CHECK: return %[[r]] +func.func @test_promote_if_one_iteration(%a: index) -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = scf.for %j = %c0 to %c1 step %c1 iter_args(%arg0 = %a) -> index { + %1 = "test.foo"(%a) : (index) -> (index) + scf.yield %1 : index + } + return %0 : index +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.loop.promote_if_one_iteration %0 : !transform.any_op +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index b09c410..17c1c2b 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2409,6 +2409,7 @@ cc_library( ":AffineUtils", ":FuncDialect", ":IR", + ":LoopLikeInterface", ":SCFDialect", ":SCFTransformOpsIncGen", ":SCFTransforms", -- 2.7.4