[mlir] Introduce Transform ops for loops
authorAlex Zinenko <zinenko@google.com>
Thu, 9 Jun 2022 09:10:32 +0000 (11:10 +0200)
committerAlex Zinenko <zinenko@google.com>
Thu, 9 Jun 2022 09:41:55 +0000 (11:41 +0200)
Introduce transform ops for "for" loops, in particular for peeling, software
pipelining and unrolling, along with a couple of "IR navigation" ops. These ops
are intended to be generalized to different kinds of loops when possible and
therefore use the "loop" prefix. They currently live in the SCF dialect as
there is no clear place to put transform ops that may span across several
dialects, this decision is postponed until the ops actually need to handle
non-SCF loops.

Additionally refactor some common utilities for transform ops into trait or
interface methods, and change the loop pipelining to be a returning pattern.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D127300

26 files changed:
mlir/include/mlir/Dialect/SCF/CMakeLists.txt
mlir/include/mlir/Dialect/SCF/Patterns.h [new file with mode: 0644]
mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h [new file with mode: 0644]
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td [new file with mode: 0644]
mlir/include/mlir/Dialect/SCF/Transforms.h
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/CMakeLists.txt
mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp [new file with mode: 0644]
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/python/CMakeLists.txt
mlir/python/mlir/dialects/SCFLoopTransformOps.td [new file with mode: 0644]
mlir/python/mlir/dialects/_loop_transform_ops_ext.py [new file with mode: 0644]
mlir/python/mlir/dialects/transform/loop.py [new file with mode: 0644]
mlir/test/Dialect/SCF/transform-ops.mlir [new file with mode: 0644]
mlir/test/python/dialects/transform_loop_ext.py [new file with mode: 0644]
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

index dda48a8..cab8fa0 100644 (file)
@@ -7,3 +7,5 @@ add_public_tablegen_target(MLIRSCFPassIncGen)
 add_dependencies(mlir-headers MLIRSCFPassIncGen)
 
 add_mlir_doc(Passes SCFPasses ./ -gen-pass-doc)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/SCF/Patterns.h b/mlir/include/mlir/Dialect/SCF/Patterns.h
new file mode 100644 (file)
index 0000000..2b8e6ca
--- /dev/null
@@ -0,0 +1,54 @@
+//===- Patterns.h - SCF dialect rewrite patterns ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_PATTERNS_H
+#define MLIR_DIALECT_SCF_PATTERNS_H
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace scf {
+/// Generate a pipelined version of the scf.for loop based on the schedule given
+/// as option. This applies the mechanical transformation of changing the loop
+/// and generating the prologue/epilogue for the pipelining and doesn't make any
+/// decision regarding the schedule.
+/// Based on the options the loop is split into several stages.
+/// The transformation assumes that the scheduling given by user is valid.
+/// For example if we break a loop into 3 stages named S0, S1, S2 we would
+/// generate the following code with the number in parenthesis as the iteration
+/// index:
+/// S0(0)                        // Prologue
+/// S0(1) S1(0)                  // Prologue
+/// scf.for %I = %C0 to %N - 2 {
+///  S0(I+2) S1(I+1) S2(I)       // Pipelined kernel
+/// }
+/// S1(N) S2(N-1)                // Epilogue
+/// S2(N)                        // Epilogue
+class ForLoopPipeliningPattern : public OpRewritePattern<ForOp> {
+public:
+  ForLoopPipeliningPattern(const PipeliningOption &options,
+                           MLIRContext *context)
+      : OpRewritePattern<ForOp>(context), options(options) {}
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(forOp, rewriter);
+  }
+
+  FailureOr<ForOp> returningMatchAndRewrite(ForOp forOp,
+                                            PatternRewriter &rewriter) const;
+
+protected:
+  PipeliningOption options;
+};
+
+} // namespace scf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SCF_PATTERNS_H
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt
new file mode 100644 (file)
index 0000000..b8e09b6
--- /dev/null
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS SCFTransformOps.td)
+mlir_tablegen(SCFTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(SCFTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRSCFTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
new file mode 100644 (file)
index 0000000..49da5d9
--- /dev/null
@@ -0,0 +1,36 @@
+//===- SCFTransformOps.h - SCF transformation ops ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
+#define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+namespace func {
+class FuncOp;
+} // namespace func
+namespace scf {
+class ForOp;
+} // namespace scf
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace scf {
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace scf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
new file mode 100644 (file)
index 0000000..47a5669
--- /dev/null
@@ -0,0 +1,144 @@
+//===- SCFTransformOps.td - SCF (loop) transformation ops --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SCF_TRANSFORM_OPS
+#define SCF_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformEffects.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
+    [NavigationTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Gets a handle to the parent 'for' loop of the given operation";
+  let description = [{
+    Produces a handle to the n-th (default 1) parent `scf.for` loop for each
+    Payload IR operation associated with the operand. Fails if such a loop
+    cannot be found. The list of operations associated with the handle contains
+    parent operations in the same order as the list associated with the operand,
+    except for operations that are parents to more than one input which are only
+    present once.
+  }];
+
+  let arguments =
+    (ins PDL_Operation:$target,
+         DefaultValuedAttr<Confined<I64Attr, [IntPositive]>,
+                           "1">:$num_loops);
+  let results = (outs PDL_Operation:$parent);
+
+  let assemblyFormat = "$target attr-dict";
+}
+
+def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Outlines a loop into a named function";
+  let description = [{
+     Moves the loop into a separate function with the specified name and
+     replaces the loop in the Payload IR with a call to that function. Takes
+     care of forwarding values that are used in the loop as function arguments.
+     If the operand is associated with more than one loop, each loop will be
+     outlined into a separate function. The provided name is used as a _base_
+     for forming actual function names following SymbolTable auto-renaming
+     scheme to avoid duplicate symbols. Expects that all ops in the Payload IR
+     have a SymbolTable ancestor (typically true because of the top-level 
+     module). Returns the handle to the list of outlined functions in the same
+     order as the operand handle.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   StrAttr:$func_name);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+}
+
+def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Peels the last iteration of the loop";
+  let description = [{
+     Updates the given loop so that its step evenly divides its range and puts
+     the remaining iteration into a separate loop or a conditional. Note that
+     even though the Payload IR modification may be performed in-place, this
+     operation consumes the operand handle and produces a new one. Applies to
+     each loop associated with the operand handle individually. The results
+     follow the same order as the operand.
+
+     Note: If it can be proven statically that the step already evenly divides
+     the range, this op is a no-op. In the absence of sufficient static
+     information, this op may peel a loop, even if the step always divides the
+     range evenly at runtime.
+  }];
+
+  let arguments =
+      (ins PDL_Operation:$target,
+           DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+  }];
+}
+def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Applies software pipelining to the loop";
+  let description = [{
+    Transforms the given loops one by one to achieve software pipelining for
+    each of them. That is, performs some amount of reads from memory before the
+    loop rather than inside the loop, the same amount of writes into memory
+    after the loop, and updates each iteration to read the data for a following
+    iteration rather than the current one. The amount is specified by the
+    attributes. The values read and about to be stored are transferred as loop
+    iteration arguments. Currently supports memref and vector transfer
+    operations as memory reads/writes.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64Attr, "1">:$iteration_interval,
+                   DefaultValuedAttr<I64Attr, "10">:$read_latency);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+  }];
+}
+
+def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Unrolls the given loop with the given unroll factor";
+  let description = [{
+     Unrolls each loop associated with the given handle to have up to the given
+     number of loop body copies per iteration. If the unroll factor is larger
+     than the loop trip count, the latter is used as the unroll factor instead.
+     Does not produce a new handle as the operation may result in the loop being
+     removed after a full unrolling.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       Confined<I64Attr, [IntPositive]>:$factor);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::LogicalResult applyToOne(::mlir::scf::ForOp loop);
+  }];
+}
+
+#endif // SCF_TRANSFORM_OPS
index 228a6b5..7f33b3f 100644 (file)
@@ -158,26 +158,8 @@ struct PipeliningOption {
   // TODO: add option to decide if the prologue should be peeled.
 };
 
-/// Populate patterns for SCF software pipelining transformation.
-/// This transformation generates the pipelined loop and doesn't do any
-/// assumptions on the schedule dictated by the option structure.
-/// Software pipelining is usually done in two part. The first part of
-/// pipelining is to schedule the loop and assign a stage and cycle to each
-/// operations. This is highly dependent on the target and is implemented as an
-/// heuristic based on operation latencies, and other hardware characteristics.
-/// The second part is to take the schedule and generate the pipelined loop as
-/// well as the prologue and epilogue. It is independent of the target.
-/// This pattern only implement the second part.
-/// For example if we break a loop into 3 stages named S0, S1, S2 we would
-/// generate the following code with the number in parenthesis the iteration
-/// index:
-/// S0(0)                        // Prologue
-/// S0(1) S1(0)                  // Prologue
-/// scf.for %I = %C0 to %N - 2 {
-///  S0(I+2) S1(I+1) S2(I)       // Pipelined kernel
-/// }
-/// S1(N) S2(N-1)                // Epilogue
-/// S2(N)                        // Epilogue
+/// Populate patterns for SCF software pipelining transformation. See the
+/// ForLoopPipeliningPattern for the transformation details.
 void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
                                        const PipeliningOption &options);
 
index 09032cb..ebd055c 100644 (file)
@@ -28,6 +28,7 @@ class ValueRange;
 class Value;
 
 namespace func {
+class CallOp;
 class FuncOp;
 } // namespace func
 
@@ -63,12 +64,13 @@ scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
 /// `outlinedFuncBody` to alloc simple canonicalizations.
 /// Creates a new FuncOp and thus cannot be used in a FuncOp pass.
 /// The client is responsible for providing a unique `funcName` that will not
-/// collide with another FuncOp name.
+/// collide with another FuncOp name.  If `callOp` is provided, it will be set
+/// to point to the operation that calls the outlined function.
 // TODO: support more than single-block regions.
 // TODO: more flexible constant handling.
-FailureOr<func::FuncOp> outlineSingleBlockRegion(RewriterBase &rewriter,
-                                                 Location loc, Region &region,
-                                                 StringRef funcName);
+FailureOr<func::FuncOp>
+outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region,
+                         StringRef funcName, func::CallOp *callOp = nullptr);
 
 /// Outline the then and/or else regions of `ifOp` as follows:
 ///  - if `thenFn` is not null, `thenFnName` must be specified and the `then`
index d08267c..61a01cb 100644 (file)
@@ -451,7 +451,7 @@ struct PayloadIRResource
   StringRef getName() override { return "transform.payload_ir"; }
 };
 
-/// Trait implementing the MemoryEffectOpInterface for single-operand
+/// Trait implementing the MemoryEffectOpInterface for single-operand zero- or
 /// single-result operations that "consume" their operand and produce a new
 /// result.
 template <typename OpTy>
@@ -468,6 +468,47 @@ public:
     effects.emplace_back(MemoryEffects::Free::get(),
                          this->getOperation()->getOperand(0),
                          TransformMappingResource::get());
+    if (this->getOperation()->getNumResults() == 1) {
+      effects.emplace_back(MemoryEffects::Allocate::get(),
+                           this->getOperation()->getResult(0),
+                           TransformMappingResource::get());
+      effects.emplace_back(MemoryEffects::Write::get(),
+                           this->getOperation()->getResult(0),
+                           TransformMappingResource::get());
+    }
+    effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+  }
+
+  /// Checks that the op matches the expectations of this trait.
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
+                  "expected single-operand op");
+    static_assert(OpTy::template hasTrait<OpTrait::ZeroResults>() ||
+                      OpTy::template hasTrait<OpTrait::OneResult>(),
+                  "expected zero- or single-result op");
+    if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
+      op->emitError()
+          << "FunctionalStyleTransformOpTrait should only be attached to ops "
+             "that implement MemoryEffectOpInterface";
+    }
+    return success();
+  }
+};
+
+/// Trait implementing the MemoryEffectOpInterface for single-operand
+/// single-result operations that use their operand without consuming and
+/// without modifying the Payload IR to produce a new handle.
+template <typename OpTy>
+class NavigationTransformOpTrait
+    : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {
+public:
+  /// This op produces handles to the Payload IR without consuming the original
+  /// handles and without modifying the IR itself.
+  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+    effects.emplace_back(MemoryEffects::Read::get(),
+                         this->getOperation()->getOperand(0),
+                         TransformMappingResource::get());
     effects.emplace_back(MemoryEffects::Allocate::get(),
                          this->getOperation()->getResult(0),
                          TransformMappingResource::get());
@@ -475,19 +516,17 @@ public:
                          this->getOperation()->getResult(0),
                          TransformMappingResource::get());
     effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
-    effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
   }
 
-  /// Checks that the op matches the expectations of this trait.
+  /// Checks that the op matches the expectation of this trait.
   static LogicalResult verifyTrait(Operation *op) {
     static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
                   "expected single-operand op");
     static_assert(OpTy::template hasTrait<OpTrait::OneResult>(),
                   "expected single-result op");
     if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
-      op->emitError()
-          << "FunctionalStyleTransformOpTrait should only be attached to ops "
-             "that implement MemoryEffectOpInterface";
+      op->emitError() << "NavigationTransformOpTrait should only be attached "
+                         "to ops that implement MemoryEffectOpInterface";
     }
     return success();
   }
index fad845c..ff85a74 100644 (file)
@@ -47,6 +47,19 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
           "::mlir::transform::TransformState &":$state
     )>,
   ];
+
+  let extraSharedClassDeclaration = [{
+    /// Emits a generic transform error for the current transform operation
+    /// targeting the given Payload IR operation and returns failure. Should
+    /// be only used as a last resort when the transformation itself provides
+    /// no further indication as to the reason of the failure.
+    ::mlir::LogicalResult reportUnknownTransformError(
+        ::mlir::Operation *target) {
+      ::mlir::InFlightDiagnostic diag = $_op->emitError() << "failed to apply";
+      diag.attachNote(target->getLoc()) << "attempted to apply to this op";
+      return diag;
+    }
+  }];
 }
 
 def FunctionalStyleTransformOpTrait
@@ -58,4 +71,8 @@ def TransformEachOpTrait : NativeOpTrait<"TransformEachOpTrait"> {
   let cppNamespace = "::mlir::transform";
 }
 
+def NavigationTransformOpTrait : NativeOpTrait<"NavigationTransformOpTrait"> {
+  let cppNamespace = "::mlir::transform";
+}
+
 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
index 9bfd2d3..28ef83e 100644 (file)
@@ -19,7 +19,7 @@ include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 
 def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
-     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
   let summary = "Gets handles to the closest isolated-from-above parents";
   let description = [{
     The handles defined by this Transform op correspond to the closest isolated
index 862c996..548ab89 100644 (file)
@@ -47,6 +47,7 @@
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
@@ -107,6 +108,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
 
   // Register all dialect extensions.
   linalg::registerTransformDialectExtension(registry);
+  scf::registerTransformDialectExtension(registry);
 
   // Register all external models.
   arith::registerBufferizableOpInterfaceExternalModels(registry);
index b081e24..b02f933 100644 (file)
@@ -89,9 +89,7 @@ FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
   if (succeeded(depthwise))
     return depthwise;
 
-  InFlightDiagnostic diag = emitError() << "failed to apply";
-  diag.attachNote(target.getLoc()) << "attempted to apply to this op";
-  return diag;
+  return reportUnknownTransformError(target);
 }
 
 //===----------------------------------------------------------------------===//
@@ -107,9 +105,7 @@ FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
   if (succeeded(generic))
     return generic;
 
-  InFlightDiagnostic diag = emitError() << "failed to apply";
-  diag.attachNote(target.getLoc()) << "attempted to apply to this op";
-  return diag;
+  return reportUnknownTransformError(target);
 }
 
 //===----------------------------------------------------------------------===//
@@ -416,11 +412,8 @@ FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) {
   if (getVectorizePadding())
     linalg::populatePadOpVectorizationPatterns(patterns);
 
-  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) {
-    InFlightDiagnostic diag = emitError() << "failed to apply";
-    diag.attachNote(target->getLoc()) << "target op";
-    return diag;
-  }
+  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
+    return reportUnknownTransformError(target);
   return target;
 }
 
index 5377286..2301a32 100644 (file)
@@ -16,5 +16,6 @@ add_mlir_dialect_library(MLIRSCF
   MLIRSideEffectInterfaces
   )
 
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
new file mode 100644 (file)
index 0000000..611a78f
--- /dev/null
@@ -0,0 +1,20 @@
+add_mlir_dialect_library(MLIRSCFTransformOps
+  SCFTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF/TransformOps
+
+  DEPENDS
+  MLIRSCFTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAffine
+  MLIRFunc
+  MLIRIR
+  MLIRPDL
+  MLIRSCF
+  MLIRSCFTransforms
+  MLIRSCFUtils
+  MLIRTransformDialect
+  MLIRVector
+)
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
new file mode 100644 (file)
index 0000000..213e1fd
--- /dev/null
@@ -0,0 +1,232 @@
+//===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/Patterns.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+
+namespace {
+/// A simple pattern rewriter that implements no special logic.
+class SimpleRewriter : public PatternRewriter {
+public:
+  SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// GetParentForOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+transform::GetParentForOp::apply(transform::TransformResults &results,
+                                 transform::TransformState &state) {
+  SetVector<Operation *> parents;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    scf::ForOp loop;
+    Operation *current = target;
+    for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
+      loop = current->getParentOfType<scf::ForOp>();
+      if (!loop) {
+        InFlightDiagnostic diag = emitError() << "could not find an '"
+                                              << scf::ForOp::getOperationName()
+                                              << "' parent";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+      current = loop;
+    }
+    parents.insert(loop);
+  }
+  results.set(getResult().cast<OpResult>(), parents.getArrayRef());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopOutlineOp
+//===----------------------------------------------------------------------===//
+
+/// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
+/// the provided rewriter for all operations to remain compatible with the
+/// rewriting infra, as opposed to just splicing the op in place.
+static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
+                                                Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  scf::ExecuteRegionOp executeRegionOp =
+      b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
+  {
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
+    Operation *clonedOp = b.cloneWithoutRegions(*op);
+    Region &clonedRegion = clonedOp->getRegions().front();
+    assert(clonedRegion.empty() && "expected empty region");
+    b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
+                         clonedRegion.end());
+    b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
+  }
+  b.replaceOp(op, executeRegionOp.getResults());
+  return executeRegionOp;
+}
+
+LogicalResult
+transform::LoopOutlineOp::apply(transform::TransformResults &results,
+                                transform::TransformState &state) {
+  SmallVector<Operation *> transformed;
+  DenseMap<Operation *, SymbolTable> symbolTables;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    Location location = target->getLoc();
+    Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
+    SimpleRewriter rewriter(getContext());
+    scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
+    if (!exec) {
+      InFlightDiagnostic diag = emitError() << "failed to outline";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    func::CallOp call;
+    FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
+        rewriter, location, exec.getRegion(), getFuncName(), &call);
+
+    if (failed(outlined))
+      return reportUnknownTransformError(target);
+
+    if (symbolTableOp) {
+      SymbolTable &symbolTable =
+          symbolTables.try_emplace(symbolTableOp, symbolTableOp)
+              .first->getSecond();
+      symbolTable.insert(*outlined);
+      call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
+    }
+    transformed.push_back(*outlined);
+  }
+  results.set(getTransformed().cast<OpResult>(), transformed);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopPeelOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop) {
+  scf::ForOp result;
+  IRRewriter rewriter(loop->getContext());
+  LogicalResult status =
+      scf::peelAndCanonicalizeForLoop(rewriter, loop, result);
+  if (failed(status)) {
+    if (getFailIfAlreadyDivisible())
+      return reportUnknownTransformError(loop);
+    return loop;
+  }
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// LoopPipelineOp
+//===----------------------------------------------------------------------===//
+
+/// Callback for PipeliningOption. Populates `schedule` with the mapping from an
+/// operation to its logical time position given the iteration interval and the
+/// read latency. The latter is only relevant for vector transfers.
+static void
+loopScheduling(scf::ForOp forOp,
+               std::vector<std::pair<Operation *, unsigned>> &schedule,
+               unsigned iterationInterval, unsigned readLatency) {
+  auto getLatency = [&](Operation *op) -> unsigned {
+    if (isa<vector::TransferReadOp>(op))
+      return readLatency;
+    return 1;
+  };
+
+  DenseMap<Operation *, unsigned> opCycles;
+  std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
+  for (Operation &op : forOp.getBody()->getOperations()) {
+    if (isa<scf::YieldOp>(op))
+      continue;
+    unsigned earlyCycle = 0;
+    for (Value operand : op.getOperands()) {
+      Operation *def = operand.getDefiningOp();
+      if (!def)
+        continue;
+      earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
+    }
+    opCycles[&op] = earlyCycle;
+    wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
+  }
+  for (auto it : wrappedSchedule) {
+    for (Operation *op : it.second) {
+      unsigned cycle = opCycles[op];
+      schedule.push_back(std::make_pair(op, cycle / iterationInterval));
+    }
+  }
+}
+
+FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) {
+  scf::PipeliningOption options;
+  options.getScheduleFn =
+      [this](scf::ForOp forOp,
+             std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
+        loopScheduling(forOp, schedule, getIterationInterval(),
+                       getReadLatency());
+      };
+
+  scf::ForLoopPipeliningPattern pattern(options, loop->getContext());
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(loop);
+  FailureOr<scf::ForOp> patternResult =
+      pattern.returningMatchAndRewrite(loop, rewriter);
+  if (failed(patternResult))
+    return reportUnknownTransformError(loop);
+  return patternResult;
+}
+
+//===----------------------------------------------------------------------===//
+// LoopUnrollOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) {
+  if (failed(loopUnrollByFactor(loop, getFactor())))
+    return reportUnknownTransformError(loop);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class SCFTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          SCFTransformDialectExtension> {
+public:
+  SCFTransformDialectExtension() {
+    declareDependentDialect<AffineDialect>();
+    declareDependentDialect<func::FuncDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
+
+void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<SCFTransformDialectExtension>();
+}
index 659d248..cd9e036 100644 (file)
@@ -12,6 +12,7 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/SCF/Patterns.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
@@ -436,76 +437,53 @@ void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
   it->second[idx] = el;
 }
 
-/// Generate a pipelined version of the scf.for loop based on the schedule given
-/// as option. This applies the mechanical transformation of changing the loop
-/// and generating the prologue/epilogue for the pipelining and doesn't make any
-/// decision regarding the schedule.
-/// Based on the option the loop is split into several stages.
-/// The transformation assumes that the scheduling given by user is valid.
-/// For example if we break a loop into 3 stages named S0, S1, S2 we would
-/// generate the following code with the number in parenthesis the iteration
-/// index:
-/// S0(0)                        // Prologue
-/// S0(1) S1(0)                  // Prologue
-/// scf.for %I = %C0 to %N - 2 {
-///  S0(I+2) S1(I+1) S2(I)       // Pipelined kernel
-/// }
-/// S1(N) S2(N-1)                // Epilogue
-/// S2(N)                        // Epilogue
-struct ForLoopPipelining : public OpRewritePattern<ForOp> {
-  ForLoopPipelining(const PipeliningOption &options, MLIRContext *context)
-      : OpRewritePattern<ForOp>(context), options(options) {}
-  LogicalResult matchAndRewrite(ForOp forOp,
-                                PatternRewriter &rewriter) const override {
+} // namespace
 
-    LoopPipelinerInternal pipeliner;
-    if (!pipeliner.initializeLoopInfo(forOp, options))
-      return failure();
+FailureOr<ForOp> ForLoopPipeliningPattern::returningMatchAndRewrite(
+    ForOp forOp, PatternRewriter &rewriter) const {
 
-    // 1. Emit prologue.
-    pipeliner.emitPrologue(rewriter);
+  LoopPipelinerInternal pipeliner;
+  if (!pipeliner.initializeLoopInfo(forOp, options))
+    return failure();
 
-    // 2. Track values used across stages. When a value cross stages it will
-    // need to be passed as loop iteration arguments.
-    // We first collect the values that are used in a different stage than where
-    // they are defined.
-    llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
-        crossStageValues = pipeliner.analyzeCrossStageValues();
+  // 1. Emit prologue.
+  pipeliner.emitPrologue(rewriter);
 
-    // Mapping between original loop values used cross stage and the block
-    // arguments associated after pipelining. A Value may map to several
-    // arguments if its liverange spans across more than 2 stages.
-    llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
-    // 3. Create the new kernel loop and return the block arguments mapping.
-    ForOp newForOp =
-        pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
-    // Create the kernel block, order ops based on user choice and remap
-    // operands.
-    pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter);
+  // 2. Track values used across stages. When a value cross stages it will
+  // need to be passed as loop iteration arguments.
+  // We first collect the values that are used in a different stage than where
+  // they are defined.
+  llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
+      crossStageValues = pipeliner.analyzeCrossStageValues();
 
-    llvm::SmallVector<Value> returnValues =
-        newForOp.getResults().take_front(forOp->getNumResults());
-    if (options.peelEpilogue) {
-      // 4. Emit the epilogue after the new forOp.
-      rewriter.setInsertionPointAfter(newForOp);
-      returnValues = pipeliner.emitEpilogue(rewriter);
-    }
-    // 5. Erase the original loop and replace the uses with the epilogue output.
-    if (forOp->getNumResults() > 0)
-      rewriter.replaceOp(forOp, returnValues);
-    else
-      rewriter.eraseOp(forOp);
+  // Mapping between original loop values used cross stage and the block
+  // arguments associated after pipelining. A Value may map to several
+  // arguments if its liverange spans across more than 2 stages.
+  llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
+  // 3. Create the new kernel loop and return the block arguments mapping.
+  ForOp newForOp =
+      pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
+  // Create the kernel block, order ops based on user choice and remap
+  // operands.
+  pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter);
 
-    return success();
+  llvm::SmallVector<Value> returnValues =
+      newForOp.getResults().take_front(forOp->getNumResults());
+  if (options.peelEpilogue) {
+    // 4. Emit the epilogue after the new forOp.
+    rewriter.setInsertionPointAfter(newForOp);
+    returnValues = pipeliner.emitEpilogue(rewriter);
   }
+  // 5. Erase the original loop and replace the uses with the epilogue output.
+  if (forOp->getNumResults() > 0)
+    rewriter.replaceOp(forOp, returnValues);
+  else
+    rewriter.eraseOp(forOp);
 
-protected:
-  PipeliningOption options;
-};
-
-} // namespace
+  return newForOp;
+}
 
 void mlir::scf::populateSCFLoopPipeliningPatterns(
     RewritePatternSet &patterns, const PipeliningOption &options) {
-  patterns.add<ForLoopPipelining>(options, patterns.getContext());
+  patterns.add<ForLoopPipeliningPattern>(options, patterns.getContext());
 }
index 0910cb3..bce73bd 100644 (file)
@@ -105,13 +105,16 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
 /// This method also clones the `arith::ConstantIndexOp` at the start of
-/// `outlinedFuncBody` to alloc simple canonicalizations.
+/// `outlinedFuncBody` to alloc simple canonicalizations. If `callOp` is
+/// provided, it will be set to point to the operation that calls the outlined
+/// function.
 // TODO: support more than single-block regions.
 // TODO: more flexible constant handling.
 FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                        Location loc,
                                                        Region &region,
-                                                       StringRef funcName) {
+                                                       StringRef funcName,
+                                                       func::CallOp *callOp) {
   assert(!funcName.empty() && "funcName cannot be empty");
   if (!region.hasOneBlock())
     return failure();
@@ -176,8 +179,9 @@ FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
     SmallVector<Value> callValues;
     llvm::append_range(callValues, newBlock->getArguments());
     llvm::append_range(callValues, outlinedValues);
-    Operation *call =
-        rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
+    auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
+    if (callOp)
+      *callOp = call;
 
     // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
     // Clone `originalTerminator` to take the callOp results then erase it from
index 91ad5ff..9dadd26 100644 (file)
@@ -137,17 +137,6 @@ LogicalResult transform::GetClosestIsolatedParentOp::apply(
   return success();
 }
 
-void transform::GetClosestIsolatedParentOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
-                       TransformMappingResource::get());
-  effects.emplace_back(MemoryEffects::Allocate::get(), getParent(),
-                       TransformMappingResource::get());
-  effects.emplace_back(MemoryEffects::Write::get(), getParent(),
-                       TransformMappingResource::get());
-  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
-}
-
 //===----------------------------------------------------------------------===//
 // PDLMatchOp
 //===----------------------------------------------------------------------===//
index 17048e8..13b35f1 100644 (file)
@@ -128,6 +128,16 @@ declare_mlir_dialect_python_bindings(
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/SCFLoopTransformOps.td
+  SOURCES
+    dialects/_loop_transform_ops_ext.py
+    dialects/transform/loop.py
+  DIALECT_NAME transform
+  EXTENSION_NAME loop_transform)
+
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/LinalgStructuredTransformOps.td
   SOURCES
     dialects/_structured_transform_ops_ext.py
diff --git a/mlir/python/mlir/dialects/SCFLoopTransformOps.td b/mlir/python/mlir/dialects/SCFLoopTransformOps.td
new file mode 100644 (file)
index 0000000..5ef07fc
--- /dev/null
@@ -0,0 +1,21 @@
+//===-- SCFLoopTransformOps.td -----------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the Python bindings generator for the loop transform ops
+// provided by the SCF (and other) dialects.
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS
+#define PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.td"
+
+#endif // PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS
diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
new file mode 100644 (file)
index 0000000..7452c42
--- /dev/null
@@ -0,0 +1,113 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+  from ..dialects import pdl
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Union
+
+
+def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]],
+                    default_value: int = None):
+  if isinstance(arg, IntegerAttr):
+    return arg
+
+  if arg is None:
+    assert default_value is not None, "must provide default value"
+    arg = default_value
+
+  return IntegerAttr.get(IntegerType.get_signless(64), arg)
+
+
+class GetParentForOp:
+  """Extension for GetParentForOp."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               num_loops: int = 1,
+               ip=None,
+               loc=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        num_loops=_get_int64_attr(num_loops, default_value=1),
+        ip=ip,
+        loc=loc)
+
+
+class LoopOutlineOp:
+  """Extension for LoopOutlineOp."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               func_name: Union[str, StringAttr],
+               ip=None,
+               loc=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        func_name=(func_name if isinstance(func_name, StringAttr) else
+                   StringAttr.get(func_name)),
+        ip=ip,
+        loc=loc)
+
+
+class LoopPeelOp:
+  """Extension for LoopPeelOp."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               fail_if_already_divisible: Union[bool, BoolAttr] = False,
+               ip=None,
+               loc=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        fail_if_already_divisible=(fail_if_already_divisible if isinstance(
+            fail_if_already_divisible, BoolAttr) else
+                                   BoolAttr.get(fail_if_already_divisible)),
+        ip=ip,
+        loc=loc)
+
+
+class LoopPipelineOp:
+  """Extension for LoopPipelineOp."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               iteration_interval: Optional[Union[int, IntegerAttr]] = None,
+               read_latency: Optional[Union[int, IntegerAttr]] = None,
+               ip=None,
+               loc=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
+        read_latency=_get_int64_attr(read_latency, default_value=10),
+        ip=ip,
+        loc=loc)
+
+
+class LoopUnrollOp:
+  """Extension for LoopUnrollOp."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               factor: Union[int, IntegerAttr],
+               ip=None,
+               loc=None):
+    super().__init__(
+        _get_op_result_or_value(target),
+        factor=_get_int64_attr(factor),
+        ip=ip,
+        loc=loc)
diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py
new file mode 100644 (file)
index 0000000..86f7278
--- /dev/null
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._loop_transform_ops_gen import *
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
new file mode 100644 (file)
index 0000000..deaf367
--- /dev/null
@@ -0,0 +1,264 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @get_parent_for_op
+func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
+  // expected-remark @below {{first loop}}
+  scf.for %i = %arg0 to %arg1 step %arg2 {
+    // expected-remark @below {{second loop}}
+    scf.for %j = %arg0 to %arg1 step %arg2 {
+      // expected-remark @below {{third loop}}
+      scf.for %k = %arg0 to %arg1 step %arg2 {
+        arith.addi %i, %j : index
+      }
+    }
+  }
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_addi : benefit(1) {
+    %args = operands
+    %results = types
+    %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %op with "transform.dialect"
+  }
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_addi in %arg1
+    // CHECK: = transform.loop.get_parent_for
+    %1 = transform.loop.get_parent_for %0
+    %2 = transform.loop.get_parent_for %0 { num_loops = 2 }
+    %3 = transform.loop.get_parent_for %0 { num_loops = 3 }
+    transform.test_print_remark_at_operand %1, "third loop"
+    transform.test_print_remark_at_operand %2, "second loop"
+    transform.test_print_remark_at_operand %3, "first loop"
+  }
+}
+
+// -----
+
+func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
+  // expected-note @below {{target op}}
+  arith.addi %arg0, %arg1 : index  
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_addi : benefit(1) {
+    %args = operands
+    %results = types
+    %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %op with "transform.dialect"
+  }
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_addi in %arg1
+    // expected-error @below {{could not find an 'scf.for' parent}}
+    %1 = transform.loop.get_parent_for %0
+  }
+}
+
+// -----
+
+// Outlined functions:
+//
+// CHECK: func @foo(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}})
+// CHECK:   scf.for
+// CHECK:     arith.addi
+//
+// CHECK: func @foo[[SUFFIX:.+]](%{{.+}}, %{{.+}}, %{{.+}})
+// CHECK:   scf.for
+// CHECK:     arith.addi
+//
+// CHECK-LABEL @loop_outline_op
+func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) {
+  // CHECK: scf.for
+  // CHECK-NOT: scf.for
+  // CHECK:   scf.execute_region
+  // CHECK:     func.call @foo
+  scf.for %i = %arg0 to %arg1 step %arg2 {
+    scf.for %j = %arg0 to %arg1 step %arg2 {
+      arith.addi %i, %j : index
+    }
+  }
+  // CHECK: scf.execute_region
+  // CHECK-NOT: scf.for
+  // CHECK:   func.call @foo[[SUFFIX]]
+  scf.for %j = %arg0 to %arg1 step %arg2 {
+    arith.addi %j, %j : index
+  }
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_addi : benefit(1) {
+    %args = operands
+    %results = types
+    %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %op with "transform.dialect"
+  }
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_addi in %arg1
+    %1 = transform.loop.get_parent_for %0
+    // CHECK: = transform.loop.outline %{{.*}}
+    transform.loop.outline %1 {func_name = "foo"}
+  }
+}
+
+// -----
+
+func.func private @cond() -> i1
+func.func private @body()
+
+func.func @loop_outline_op_multi_region() {
+  // expected-note @below {{target op}}
+  scf.while : () -> () {
+    %0 = func.call @cond() : () -> i1
+    scf.condition(%0)
+  } do {
+  ^bb0:
+    func.call @body() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_while : benefit(1) {
+    %args = operands
+    %results = types
+    %op = operation "scf.while"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %op with "transform.dialect"
+  }
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_while in %arg1
+    // expected-error @below {{failed to outline}}
+    transform.loop.outline %0 {func_name = "foo"}
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_peel_op
+func.func @loop_peel_op() {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[C42:.+]] = arith.constant 42
+  // CHECK: %[[C5:.+]] = arith.constant 5
+  // CHECK: %[[C40:.+]] = arith.constant 40
+  // CHECK: scf.for %{{.+}} = %[[C0]] to %[[C40]] step %[[C5]]
+  // CHECK:   arith.addi
+  // CHECK: scf.for %{{.+}} = %[[C40]] to %[[C42]] step %[[C5]]
+  // CHECK:   arith.addi
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 42 : index
+  %2 = arith.constant 5 : index
+  scf.for %i = %0 to %1 step %2 {
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_addi : benefit(1) {
+    %args = operands
+    %results = types
+    %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %op with "transform.dialect"
+  }
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_addi in %arg1
+    %1 = transform.loop.get_parent_for %0
+    transform.loop.peel %1
+  }
+}
+
+// -----
+
+func.func @loop_pipeline_op(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  // CHECK: memref.load %[[MEMREF:.+]][%{{.+}}]
+  // CHECK: memref.load %[[MEMREF]]
+  // CHECK: arith.addf
+  // CHECK: scf.for
+  // CHECK:   memref.load
+  // CHECK:   arith.addf
+  // CHECK:   memref.store
+  // CHECK: arith.addf
+  // CHECK: memref.store
+  // CHECK: memref.store
+  // expected-remark @below {{transformed}}
+  scf.for %i0 = %c0 to %c4 step %c1 {
+    %A_elem = memref.load %A[%i0] : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %cf : f32
+    memref.store %A1_elem, %result[%i0] : memref<?xf32>
+  }
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_addf : benefit(1) {
+    %args = operands
+    %results = types
+    %op = operation "arith.addf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %op with "transform.dialect"
+  }
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_addf in %arg1
+    %1 = transform.loop.get_parent_for %0
+    %2 = transform.loop.pipeline %1
+    // Verify that the returned handle is usable.
+    transform.test_print_remark_at_operand %2, "transformed"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_op
+func.func @loop_unroll_op() {
+  %c0 = arith.constant 0 : index
+  %c42 = arith.constant 42 : index
+  %c5 = arith.constant 5 : index
+  // CHECK: scf.for %[[I:.+]] =
+  scf.for %i = %c0 to %c42 step %c5 {
+    // CHECK-COUNT-4: arith.addi %[[I]]
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_addi : benefit(1) {
+    %args = operands
+    %results = types
+    %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %op with "transform.dialect"
+  }
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_addi in %arg1
+    %1 = transform.loop.get_parent_for %0
+    transform.loop.unroll %1 { factor = 4 }
+  }
+}
+
diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py
new file mode 100644 (file)
index 0000000..a324266
--- /dev/null
@@ -0,0 +1,71 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects import pdl
+from mlir.dialects.transform import loop
+
+
+def run(f):
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      print("\nTEST:", f.__name__)
+      f()
+    print(module)
+  return f
+
+
+@run
+def getParentLoop():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    loop.GetParentForOp(sequence.bodyTarget, num_loops=2)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: getParentLoop
+  # CHECK: = transform.loop.get_parent_for %
+  # CHECK: num_loops = 2
+
+
+@run
+def loopOutline():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    loop.LoopOutlineOp(sequence.bodyTarget, func_name="foo")
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: loopOutline
+  # CHECK: = transform.loop.outline %
+  # CHECK: func_name = "foo"
+
+
+@run
+def loopPeel():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    loop.LoopPeelOp(sequence.bodyTarget)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: loopPeel
+  # CHECK: = transform.loop.peel %
+
+
+@run
+def loopPipeline():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    loop.LoopPipelineOp(sequence.bodyTarget, iteration_interval=3)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: loopPipeline
+  # CHECK: = transform.loop.pipeline %
+  # CHECK-DAG: iteration_interval = 3
+  # CHECK-DAG: read_latency = 10
+
+
+@run
+def loopUnroll():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: loopUnroll
+  # CHECK: transform.loop.unroll %
+  # CHECK: factor = 42
index 516b1ec..b511d48 100644 (file)
@@ -1868,6 +1868,7 @@ cc_library(
     hdrs = [
         "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
         "include/mlir/Dialect/SCF/Passes.h",
+        "include/mlir/Dialect/SCF/Patterns.h",
         "include/mlir/Dialect/SCF/Transforms.h",
     ],
     includes = ["include"],
@@ -1892,6 +1893,59 @@ cc_library(
     ],
 )
 
+td_library(
+    name = "SCFTransformOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":PDLDialect",
+        ":TransformDialectTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "SCFTransformOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td",
+    deps = [
+        ":SCFTransformOpsTdFiles",
+    ],
+)
+
+cc_library(
+    name = "SCFTransformOps",
+    srcs = glob(["lib/Dialect/SCF/TransformOps/*.cpp"]),
+    hdrs = glob(["include/mlir/Dialect/SCF/TransformOps/*.h"]),
+    includes = ["include"],
+    deps = [
+        ":Affine",
+        ":FuncDialect",
+        ":IR",
+        ":PDLDialect",
+        ":SCFDialect",
+        ":SCFTransformOpsIncGen",
+        ":SCFTransforms",
+        ":SCFUtils",
+        ":SideEffectInterfaces",
+        ":TransformDialect",
+        ":VectorOps",
+        "//llvm:Support",
+    ],
+)
+
 ##---------------------------------------------------------------------------##
 # SparseTensor dialect.
 ##---------------------------------------------------------------------------##
@@ -2601,6 +2655,7 @@ cc_library(
         ],
         exclude = [
             "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
+            "include/mlir/Dialect/SCF/Patterns.h",
             "include/mlir/Dialect/SCF/Transforms.h",
         ],
     ),
@@ -6299,6 +6354,7 @@ cc_library(
         ":SCFPassIncGen",
         ":SCFToGPUPass",
         ":SCFToStandard",
+        ":SCFTransformOps",
         ":SCFTransforms",
         ":SDBM",
         ":SPIRVDialect",
index c94bc5d..c013492 100644 (file)
@@ -878,11 +878,33 @@ gentbl_filegroup(
     ],
 )
 
+gentbl_filegroup(
+    name = "LoopTransformOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=transform",
+                "-dialect-extension=loop_transform",
+            ],
+            "mlir/dialects/_loop_transform_ops_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/SCFLoopTransformOps.td",
+    deps = [
+        ":TransformOpsPyTdFiles",
+        "//mlir:SCFTransformOpsTdFiles",
+    ],
+)
+
 filegroup(
     name = "TransformOpsPyFiles",
     srcs = [
+        "mlir/dialects/_loop_transform_ops_ext.py",
         "mlir/dialects/_structured_transform_ops_ext.py",
         "mlir/dialects/_transform_ops_ext.py",
+        ":LoopTransformOpsPyGen",
         ":StructuredTransformOpsPyGen",
         ":TransformOpsPyGen",
     ],