From 5f0d4f208e24a3e9f7369b712c5c2598dd5582d4 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 9 Jun 2022 11:10:32 +0200 Subject: [PATCH] [mlir] Introduce Transform ops for loops Introduce transform ops for "for" loops, in particular for peeling, software pipelining and unrolling, along with a couple of "IR navigation" ops. These ops are intended to be generalized to different kinds of loops when possible and therefore use the "loop" prefix. They currently live in the SCF dialect as there is no clear place to put transform ops that may span across several dialects, this decision is postponed until the ops actually need to handle non-SCF loops. Additionally refactor some common utilities for transform ops into trait or interface methods, and change the loop pipelining to be a returning pattern. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D127300 --- mlir/include/mlir/Dialect/SCF/CMakeLists.txt | 2 + mlir/include/mlir/Dialect/SCF/Patterns.h | 54 +++++ .../mlir/Dialect/SCF/TransformOps/CMakeLists.txt | 4 + .../Dialect/SCF/TransformOps/SCFTransformOps.h | 36 +++ .../Dialect/SCF/TransformOps/SCFTransformOps.td | 144 +++++++++++ mlir/include/mlir/Dialect/SCF/Transforms.h | 22 +- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 10 +- .../Dialect/Transform/IR/TransformInterfaces.h | 51 +++- .../Dialect/Transform/IR/TransformInterfaces.td | 17 ++ .../mlir/Dialect/Transform/IR/TransformOps.td | 2 +- mlir/include/mlir/InitAllDialects.h | 2 + .../Linalg/TransformOps/LinalgTransformOps.cpp | 15 +- mlir/lib/Dialect/SCF/CMakeLists.txt | 1 + mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt | 20 ++ .../Dialect/SCF/TransformOps/SCFTransformOps.cpp | 232 ++++++++++++++++++ mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 100 +++----- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 12 +- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 11 - mlir/python/CMakeLists.txt | 10 + mlir/python/mlir/dialects/SCFLoopTransformOps.td | 21 ++ .../mlir/dialects/_loop_transform_ops_ext.py | 113 +++++++++ mlir/python/mlir/dialects/transform/loop.py | 5 + mlir/test/Dialect/SCF/transform-ops.mlir | 264 +++++++++++++++++++++ mlir/test/python/dialects/transform_loop_ext.py | 71 ++++++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 56 +++++ .../llvm-project-overlay/mlir/python/BUILD.bazel | 22 ++ 26 files changed, 1179 insertions(+), 118 deletions(-) create mode 100644 mlir/include/mlir/Dialect/SCF/Patterns.h create mode 100644 mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h create mode 100644 mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td create mode 100644 mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt create mode 100644 mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp create mode 100644 mlir/python/mlir/dialects/SCFLoopTransformOps.td create mode 100644 mlir/python/mlir/dialects/_loop_transform_ops_ext.py create mode 100644 mlir/python/mlir/dialects/transform/loop.py create mode 100644 mlir/test/Dialect/SCF/transform-ops.mlir create mode 100644 mlir/test/python/dialects/transform_loop_ext.py diff --git a/mlir/include/mlir/Dialect/SCF/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/CMakeLists.txt index dda48a8..cab8fa0 100644 --- a/mlir/include/mlir/Dialect/SCF/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SCF/CMakeLists.txt @@ -7,3 +7,5 @@ add_public_tablegen_target(MLIRSCFPassIncGen) add_dependencies(mlir-headers MLIRSCFPassIncGen) add_mlir_doc(Passes SCFPasses ./ -gen-pass-doc) + +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/SCF/Patterns.h b/mlir/include/mlir/Dialect/SCF/Patterns.h new file mode 100644 index 0000000..2b8e6ca --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/Patterns.h @@ -0,0 +1,54 @@ +//===- Patterns.h - SCF dialect rewrite patterns ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SCF_PATTERNS_H +#define MLIR_DIALECT_SCF_PATTERNS_H + +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace scf { +/// Generate a pipelined version of the scf.for loop based on the schedule given +/// as option. This applies the mechanical transformation of changing the loop +/// and generating the prologue/epilogue for the pipelining and doesn't make any +/// decision regarding the schedule. +/// Based on the options the loop is split into several stages. +/// The transformation assumes that the scheduling given by user is valid. +/// For example if we break a loop into 3 stages named S0, S1, S2 we would +/// generate the following code with the number in parenthesis as the iteration +/// index: +/// S0(0) // Prologue +/// S0(1) S1(0) // Prologue +/// scf.for %I = %C0 to %N - 2 { +/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel +/// } +/// S1(N) S2(N-1) // Epilogue +/// S2(N) // Epilogue +class ForLoopPipeliningPattern : public OpRewritePattern { +public: + ForLoopPipeliningPattern(const PipeliningOption &options, + MLIRContext *context) + : OpRewritePattern(context), options(options) {} + LogicalResult matchAndRewrite(ForOp forOp, + PatternRewriter &rewriter) const override { + return returningMatchAndRewrite(forOp, rewriter); + } + + FailureOr returningMatchAndRewrite(ForOp forOp, + PatternRewriter &rewriter) const; + +protected: + PipeliningOption options; +}; + +} // namespace scf +} // namespace mlir + +#endif // MLIR_DIALECT_SCF_PATTERNS_H diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..b8e09b6 --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS SCFTransformOps.td) +mlir_tablegen(SCFTransformOps.h.inc -gen-op-decls) +mlir_tablegen(SCFTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRSCFTransformOpsIncGen) diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h new file mode 100644 index 0000000..49da5d9 --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h @@ -0,0 +1,36 @@ +//===- SCFTransformOps.h - SCF transformation ops ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H +#define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H + +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func +namespace scf { +class ForOp; +} // namespace scf +} // namespace mlir + +#define GET_OP_CLASSES +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace scf { +void registerTransformDialectExtension(DialectRegistry ®istry); +} // namespace scf +} // namespace mlir + +#endif // MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td new file mode 100644 index 0000000..47a5669 --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -0,0 +1,144 @@ +//===- SCFTransformOps.td - SCF (loop) transformation ops --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SCF_TRANSFORM_OPS +#define SCF_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformEffects.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/PDL/IR/PDLTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" + +def GetParentForOp : Op]> { + let summary = "Gets a handle to the parent 'for' loop of the given operation"; + let description = [{ + Produces a handle to the n-th (default 1) parent `scf.for` loop for each + Payload IR operation associated with the operand. Fails if such a loop + cannot be found. The list of operations associated with the handle contains + parent operations in the same order as the list associated with the operand, + except for operations that are parents to more than one input which are only + present once. + }]; + + let arguments = + (ins PDL_Operation:$target, + DefaultValuedAttr, + "1">:$num_loops); + let results = (outs PDL_Operation:$parent); + + let assemblyFormat = "$target attr-dict"; +} + +def LoopOutlineOp : Op]> { + let summary = "Outlines a loop into a named function"; + let description = [{ + Moves the loop into a separate function with the specified name and + replaces the loop in the Payload IR with a call to that function. Takes + care of forwarding values that are used in the loop as function arguments. + If the operand is associated with more than one loop, each loop will be + outlined into a separate function. The provided name is used as a _base_ + for forming actual function names following SymbolTable auto-renaming + scheme to avoid duplicate symbols. Expects that all ops in the Payload IR + have a SymbolTable ancestor (typically true because of the top-level + module). Returns the handle to the list of outlined functions in the same + order as the operand handle. + }]; + + let arguments = (ins PDL_Operation:$target, + StrAttr:$func_name); + let results = (outs PDL_Operation:$transformed); + + let assemblyFormat = "$target attr-dict"; +} + +def LoopPeelOp : Op { + let summary = "Peels the last iteration of the loop"; + let description = [{ + Updates the given loop so that its step evenly divides its range and puts + the remaining iteration into a separate loop or a conditional. Note that + even though the Payload IR modification may be performed in-place, this + operation consumes the operand handle and produces a new one. Applies to + each loop associated with the operand handle individually. The results + follow the same order as the operand. + + Note: If it can be proven statically that the step already evenly divides + the range, this op is a no-op. In the absence of sufficient static + information, this op may peel a loop, even if the step always divides the + range evenly at runtime. + }]; + + let arguments = + (ins PDL_Operation:$target, + DefaultValuedAttr:$fail_if_already_divisible); + let results = (outs PDL_Operation:$transformed); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop); + }]; +} + +def LoopPipelineOp : Op { + let summary = "Applies software pipelining to the loop"; + let description = [{ + Transforms the given loops one by one to achieve software pipelining for + each of them. That is, performs some amount of reads from memory before the + loop rather than inside the loop, the same amount of writes into memory + after the loop, and updates each iteration to read the data for a following + iteration rather than the current one. The amount is specified by the + attributes. The values read and about to be stored are transferred as loop + iteration arguments. Currently supports memref and vector transfer + operations as memory reads/writes. + }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$iteration_interval, + DefaultValuedAttr:$read_latency); + let results = (outs PDL_Operation:$transformed); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop); + }]; +} + +def LoopUnrollOp : Op { + let summary = "Unrolls the given loop with the given unroll factor"; + let description = [{ + Unrolls each loop associated with the given handle to have up to the given + number of loop body copies per iteration. If the unroll factor is larger + than the loop trip count, the latter is used as the unroll factor instead. + Does not produce a new handle as the operation may result in the loop being + removed after a full unrolling. + }]; + + let arguments = (ins PDL_Operation:$target, + Confined:$factor); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::LogicalResult applyToOne(::mlir::scf::ForOp loop); + }]; +} + +#endif // SCF_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h index 228a6b5..7f33b3f 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -158,26 +158,8 @@ struct PipeliningOption { // TODO: add option to decide if the prologue should be peeled. }; -/// Populate patterns for SCF software pipelining transformation. -/// This transformation generates the pipelined loop and doesn't do any -/// assumptions on the schedule dictated by the option structure. -/// Software pipelining is usually done in two part. The first part of -/// pipelining is to schedule the loop and assign a stage and cycle to each -/// operations. This is highly dependent on the target and is implemented as an -/// heuristic based on operation latencies, and other hardware characteristics. -/// The second part is to take the schedule and generate the pipelined loop as -/// well as the prologue and epilogue. It is independent of the target. -/// This pattern only implement the second part. -/// For example if we break a loop into 3 stages named S0, S1, S2 we would -/// generate the following code with the number in parenthesis the iteration -/// index: -/// S0(0) // Prologue -/// S0(1) S1(0) // Prologue -/// scf.for %I = %C0 to %N - 2 { -/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel -/// } -/// S1(N) S2(N-1) // Epilogue -/// S2(N) // Epilogue +/// Populate patterns for SCF software pipelining transformation. See the +/// ForLoopPipeliningPattern for the transformation details. void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options); diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 09032cb..ebd055c 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -28,6 +28,7 @@ class ValueRange; class Value; namespace func { +class CallOp; class FuncOp; } // namespace func @@ -63,12 +64,13 @@ scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, /// `outlinedFuncBody` to alloc simple canonicalizations. /// Creates a new FuncOp and thus cannot be used in a FuncOp pass. /// The client is responsible for providing a unique `funcName` that will not -/// collide with another FuncOp name. +/// collide with another FuncOp name. If `callOp` is provided, it will be set +/// to point to the operation that calls the outlined function. // TODO: support more than single-block regions. // TODO: more flexible constant handling. -FailureOr outlineSingleBlockRegion(RewriterBase &rewriter, - Location loc, Region ®ion, - StringRef funcName); +FailureOr +outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region ®ion, + StringRef funcName, func::CallOp *callOp = nullptr); /// Outline the then and/or else regions of `ifOp` as follows: /// - if `thenFn` is not null, `thenFnName` must be specified and the `then` diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index d08267c..61a01cb 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -451,7 +451,7 @@ struct PayloadIRResource StringRef getName() override { return "transform.payload_ir"; } }; -/// Trait implementing the MemoryEffectOpInterface for single-operand +/// Trait implementing the MemoryEffectOpInterface for single-operand zero- or /// single-result operations that "consume" their operand and produce a new /// result. template @@ -468,6 +468,47 @@ public: effects.emplace_back(MemoryEffects::Free::get(), this->getOperation()->getOperand(0), TransformMappingResource::get()); + if (this->getOperation()->getNumResults() == 1) { + effects.emplace_back(MemoryEffects::Allocate::get(), + this->getOperation()->getResult(0), + TransformMappingResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), + this->getOperation()->getResult(0), + TransformMappingResource::get()); + } + effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); + } + + /// Checks that the op matches the expectations of this trait. + static LogicalResult verifyTrait(Operation *op) { + static_assert(OpTy::template hasTrait(), + "expected single-operand op"); + static_assert(OpTy::template hasTrait() || + OpTy::template hasTrait(), + "expected zero- or single-result op"); + if (!op->getName().getInterface()) { + op->emitError() + << "FunctionalStyleTransformOpTrait should only be attached to ops " + "that implement MemoryEffectOpInterface"; + } + return success(); + } +}; + +/// Trait implementing the MemoryEffectOpInterface for single-operand +/// single-result operations that use their operand without consuming and +/// without modifying the Payload IR to produce a new handle. +template +class NavigationTransformOpTrait + : public OpTrait::TraitBase { +public: + /// This op produces handles to the Payload IR without consuming the original + /// handles and without modifying the IR itself. + void getEffects(SmallVectorImpl &effects) { + effects.emplace_back(MemoryEffects::Read::get(), + this->getOperation()->getOperand(0), + TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Allocate::get(), this->getOperation()->getResult(0), TransformMappingResource::get()); @@ -475,19 +516,17 @@ public: this->getOperation()->getResult(0), TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); } - /// Checks that the op matches the expectations of this trait. + /// Checks that the op matches the expectation of this trait. static LogicalResult verifyTrait(Operation *op) { static_assert(OpTy::template hasTrait(), "expected single-operand op"); static_assert(OpTy::template hasTrait(), "expected single-result op"); if (!op->getName().getInterface()) { - op->emitError() - << "FunctionalStyleTransformOpTrait should only be attached to ops " - "that implement MemoryEffectOpInterface"; + op->emitError() << "NavigationTransformOpTrait should only be attached " + "to ops that implement MemoryEffectOpInterface"; } return success(); } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td index fad845c..ff85a74 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -47,6 +47,19 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> { "::mlir::transform::TransformState &":$state )>, ]; + + let extraSharedClassDeclaration = [{ + /// Emits a generic transform error for the current transform operation + /// targeting the given Payload IR operation and returns failure. Should + /// be only used as a last resort when the transformation itself provides + /// no further indication as to the reason of the failure. + ::mlir::LogicalResult reportUnknownTransformError( + ::mlir::Operation *target) { + ::mlir::InFlightDiagnostic diag = $_op->emitError() << "failed to apply"; + diag.attachNote(target->getLoc()) << "attempted to apply to this op"; + return diag; + } + }]; } def FunctionalStyleTransformOpTrait @@ -58,4 +71,8 @@ def TransformEachOpTrait : NativeOpTrait<"TransformEachOpTrait"> { let cppNamespace = "::mlir::transform"; } +def NavigationTransformOpTrait : NativeOpTrait<"NavigationTransformOpTrait"> { + let cppNamespace = "::mlir::transform"; +} + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 9bfd2d3..28ef83e 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -19,7 +19,7 @@ include "mlir/Dialect/Transform/IR/TransformInterfaces.td" def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + NavigationTransformOpTrait, MemoryEffectsOpInterface]> { let summary = "Gets handles to the closest isolated-from-above parents"; let description = [{ The handles defined by this Transform op correspond to the closest isolated diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 862c996..548ab89 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -47,6 +47,7 @@ #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" @@ -107,6 +108,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { // Register all dialect extensions. linalg::registerTransformDialectExtension(registry); + scf::registerTransformDialectExtension(registry); // Register all external models. arith::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index b081e24..b02f933 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -89,9 +89,7 @@ FailureOr transform::DecomposeOp::applyToOne(LinalgOp target) { if (succeeded(depthwise)) return depthwise; - InFlightDiagnostic diag = emitError() << "failed to apply"; - diag.attachNote(target.getLoc()) << "attempted to apply to this op"; - return diag; + return reportUnknownTransformError(target); } //===----------------------------------------------------------------------===// @@ -107,9 +105,7 @@ FailureOr transform::GeneralizeOp::applyToOne(LinalgOp target) { if (succeeded(generic)) return generic; - InFlightDiagnostic diag = emitError() << "failed to apply"; - diag.attachNote(target.getLoc()) << "attempted to apply to this op"; - return diag; + return reportUnknownTransformError(target); } //===----------------------------------------------------------------------===// @@ -416,11 +412,8 @@ FailureOr VectorizeOp::applyToOne(Operation *target) { if (getVectorizePadding()) linalg::populatePadOpVectorizationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) { - InFlightDiagnostic diag = emitError() << "failed to apply"; - diag.attachNote(target->getLoc()) << "target op"; - return diag; - } + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return reportUnknownTransformError(target); return target; } diff --git a/mlir/lib/Dialect/SCF/CMakeLists.txt b/mlir/lib/Dialect/SCF/CMakeLists.txt index 5377286..2301a32 100644 --- a/mlir/lib/Dialect/SCF/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/CMakeLists.txt @@ -16,5 +16,6 @@ add_mlir_dialect_library(MLIRSCF MLIRSideEffectInterfaces ) +add_subdirectory(TransformOps) add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..611a78f --- /dev/null +++ b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_dialect_library(MLIRSCFTransformOps + SCFTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF/TransformOps + + DEPENDS + MLIRSCFTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRAffine + MLIRFunc + MLIRIR + MLIRPDL + MLIRSCF + MLIRSCFTransforms + MLIRSCFUtils + MLIRTransformDialect + MLIRVector +) diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp new file mode 100644 index 0000000..213e1fd --- /dev/null +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -0,0 +1,232 @@ +//===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/Patterns.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +using namespace mlir; + +namespace { +/// A simple pattern rewriter that implements no special logic. +class SimpleRewriter : public PatternRewriter { +public: + SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} +}; +} // namespace + +//===----------------------------------------------------------------------===// +// GetParentForOp +//===----------------------------------------------------------------------===// + +LogicalResult +transform::GetParentForOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + SetVector parents; + for (Operation *target : state.getPayloadOps(getTarget())) { + scf::ForOp loop; + Operation *current = target; + for (unsigned i = 0, e = getNumLoops(); i < e; ++i) { + loop = current->getParentOfType(); + if (!loop) { + InFlightDiagnostic diag = emitError() << "could not find an '" + << scf::ForOp::getOperationName() + << "' parent"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + current = loop; + } + parents.insert(loop); + } + results.set(getResult().cast(), parents.getArrayRef()); + return success(); +} + +//===----------------------------------------------------------------------===// +// LoopOutlineOp +//===----------------------------------------------------------------------===// + +/// Wraps the given operation `op` into an `scf.execute_region` operation. Uses +/// the provided rewriter for all operations to remain compatible with the +/// rewriting infra, as opposed to just splicing the op in place. +static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, + Operation *op) { + if (op->getNumRegions() != 1) + return nullptr; + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); + scf::ExecuteRegionOp executeRegionOp = + b.create(op->getLoc(), op->getResultTypes()); + { + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); + Operation *clonedOp = b.cloneWithoutRegions(*op); + Region &clonedRegion = clonedOp->getRegions().front(); + assert(clonedRegion.empty() && "expected empty region"); + b.inlineRegionBefore(op->getRegions().front(), clonedRegion, + clonedRegion.end()); + b.create(op->getLoc(), clonedOp->getResults()); + } + b.replaceOp(op, executeRegionOp.getResults()); + return executeRegionOp; +} + +LogicalResult +transform::LoopOutlineOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + SmallVector transformed; + DenseMap symbolTables; + for (Operation *target : state.getPayloadOps(getTarget())) { + Location location = target->getLoc(); + Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); + SimpleRewriter rewriter(getContext()); + scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); + if (!exec) { + InFlightDiagnostic diag = emitError() << "failed to outline"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + func::CallOp call; + FailureOr outlined = outlineSingleBlockRegion( + rewriter, location, exec.getRegion(), getFuncName(), &call); + + if (failed(outlined)) + return reportUnknownTransformError(target); + + if (symbolTableOp) { + SymbolTable &symbolTable = + symbolTables.try_emplace(symbolTableOp, symbolTableOp) + .first->getSecond(); + symbolTable.insert(*outlined); + call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined)); + } + transformed.push_back(*outlined); + } + results.set(getTransformed().cast(), transformed); + return success(); +} + +//===----------------------------------------------------------------------===// +// LoopPeelOp +//===----------------------------------------------------------------------===// + +FailureOr transform::LoopPeelOp::applyToOne(scf::ForOp loop) { + scf::ForOp result; + IRRewriter rewriter(loop->getContext()); + LogicalResult status = + scf::peelAndCanonicalizeForLoop(rewriter, loop, result); + if (failed(status)) { + if (getFailIfAlreadyDivisible()) + return reportUnknownTransformError(loop); + return loop; + } + return result; +} + +//===----------------------------------------------------------------------===// +// LoopPipelineOp +//===----------------------------------------------------------------------===// + +/// Callback for PipeliningOption. Populates `schedule` with the mapping from an +/// operation to its logical time position given the iteration interval and the +/// read latency. The latter is only relevant for vector transfers. +static void +loopScheduling(scf::ForOp forOp, + std::vector> &schedule, + unsigned iterationInterval, unsigned readLatency) { + auto getLatency = [&](Operation *op) -> unsigned { + if (isa(op)) + return readLatency; + return 1; + }; + + DenseMap opCycles; + std::map> wrappedSchedule; + for (Operation &op : forOp.getBody()->getOperations()) { + if (isa(op)) + continue; + unsigned earlyCycle = 0; + for (Value operand : op.getOperands()) { + Operation *def = operand.getDefiningOp(); + if (!def) + continue; + earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); + } + opCycles[&op] = earlyCycle; + wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); + } + for (auto it : wrappedSchedule) { + for (Operation *op : it.second) { + unsigned cycle = opCycles[op]; + schedule.push_back(std::make_pair(op, cycle / iterationInterval)); + } + } +} + +FailureOr transform::LoopPipelineOp::applyToOne(scf::ForOp loop) { + scf::PipeliningOption options; + options.getScheduleFn = + [this](scf::ForOp forOp, + std::vector> &schedule) mutable { + loopScheduling(forOp, schedule, getIterationInterval(), + getReadLatency()); + }; + + scf::ForLoopPipeliningPattern pattern(options, loop->getContext()); + SimpleRewriter rewriter(getContext()); + rewriter.setInsertionPoint(loop); + FailureOr patternResult = + pattern.returningMatchAndRewrite(loop, rewriter); + if (failed(patternResult)) + return reportUnknownTransformError(loop); + return patternResult; +} + +//===----------------------------------------------------------------------===// +// LoopUnrollOp +//===----------------------------------------------------------------------===// + +LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) { + if (failed(loopUnrollByFactor(loop, getFactor()))) + return reportUnknownTransformError(loop); + return success(); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class SCFTransformDialectExtension + : public transform::TransformDialectExtension< + SCFTransformDialectExtension> { +public: + SCFTransformDialectExtension() { + declareDependentDialect(); + declareDependentDialect(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" + +void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index 659d248..cd9e036 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -12,6 +12,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SCF/Patterns.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/SCF/Utils/Utils.h" @@ -436,76 +437,53 @@ void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { it->second[idx] = el; } -/// Generate a pipelined version of the scf.for loop based on the schedule given -/// as option. This applies the mechanical transformation of changing the loop -/// and generating the prologue/epilogue for the pipelining and doesn't make any -/// decision regarding the schedule. -/// Based on the option the loop is split into several stages. -/// The transformation assumes that the scheduling given by user is valid. -/// For example if we break a loop into 3 stages named S0, S1, S2 we would -/// generate the following code with the number in parenthesis the iteration -/// index: -/// S0(0) // Prologue -/// S0(1) S1(0) // Prologue -/// scf.for %I = %C0 to %N - 2 { -/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel -/// } -/// S1(N) S2(N-1) // Epilogue -/// S2(N) // Epilogue -struct ForLoopPipelining : public OpRewritePattern { - ForLoopPipelining(const PipeliningOption &options, MLIRContext *context) - : OpRewritePattern(context), options(options) {} - LogicalResult matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const override { +} // namespace - LoopPipelinerInternal pipeliner; - if (!pipeliner.initializeLoopInfo(forOp, options)) - return failure(); +FailureOr ForLoopPipeliningPattern::returningMatchAndRewrite( + ForOp forOp, PatternRewriter &rewriter) const { - // 1. Emit prologue. - pipeliner.emitPrologue(rewriter); + LoopPipelinerInternal pipeliner; + if (!pipeliner.initializeLoopInfo(forOp, options)) + return failure(); - // 2. Track values used across stages. When a value cross stages it will - // need to be passed as loop iteration arguments. - // We first collect the values that are used in a different stage than where - // they are defined. - llvm::MapVector - crossStageValues = pipeliner.analyzeCrossStageValues(); + // 1. Emit prologue. + pipeliner.emitPrologue(rewriter); - // Mapping between original loop values used cross stage and the block - // arguments associated after pipelining. A Value may map to several - // arguments if its liverange spans across more than 2 stages. - llvm::DenseMap, unsigned> loopArgMap; - // 3. Create the new kernel loop and return the block arguments mapping. - ForOp newForOp = - pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); - // Create the kernel block, order ops based on user choice and remap - // operands. - pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter); + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); - llvm::SmallVector returnValues = - newForOp.getResults().take_front(forOp->getNumResults()); - if (options.peelEpilogue) { - // 4. Emit the epilogue after the new forOp. - rewriter.setInsertionPointAfter(newForOp); - returnValues = pipeliner.emitEpilogue(rewriter); - } - // 5. Erase the original loop and replace the uses with the epilogue output. - if (forOp->getNumResults() > 0) - rewriter.replaceOp(forOp, returnValues); - else - rewriter.eraseOp(forOp); + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. Create the new kernel loop and return the block arguments mapping. + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. + pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter); - return success(); + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + returnValues = pipeliner.emitEpilogue(rewriter); } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); -protected: - PipeliningOption options; -}; - -} // namespace + return newForOp; +} void mlir::scf::populateSCFLoopPipeliningPatterns( RewritePatternSet &patterns, const PipeliningOption &options) { - patterns.add(options, patterns.getContext()); + patterns.add(options, patterns.getContext()); } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 0910cb3..bce73bd 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -105,13 +105,16 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, /// Assumes the FuncOp result types is the type of the yielded operands of the /// single block. This constraint makes it easy to determine the result. /// This method also clones the `arith::ConstantIndexOp` at the start of -/// `outlinedFuncBody` to alloc simple canonicalizations. +/// `outlinedFuncBody` to alloc simple canonicalizations. If `callOp` is +/// provided, it will be set to point to the operation that calls the outlined +/// function. // TODO: support more than single-block regions. // TODO: more flexible constant handling. FailureOr mlir::outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region ®ion, - StringRef funcName) { + StringRef funcName, + func::CallOp *callOp) { assert(!funcName.empty() && "funcName cannot be empty"); if (!region.hasOneBlock()) return failure(); @@ -176,8 +179,9 @@ FailureOr mlir::outlineSingleBlockRegion(RewriterBase &rewriter, SmallVector callValues; llvm::append_range(callValues, newBlock->getArguments()); llvm::append_range(callValues, outlinedValues); - Operation *call = - rewriter.create(loc, outlinedFunc, callValues); + auto call = rewriter.create(loc, outlinedFunc, callValues); + if (callOp) + *callOp = call; // `originalTerminator` was moved to `outlinedFuncBody` and is still valid. // Clone `originalTerminator` to take the callOp results then erase it from diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 91ad5ff..9dadd26 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -137,17 +137,6 @@ LogicalResult transform::GetClosestIsolatedParentOp::apply( return success(); } -void transform::GetClosestIsolatedParentOp::getEffects( - SmallVectorImpl &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getTarget(), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Allocate::get(), getParent(), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), getParent(), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); -} - //===----------------------------------------------------------------------===// // PDLMatchOp //===----------------------------------------------------------------------===// diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 17048e8..13b35f1 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -128,6 +128,16 @@ declare_mlir_dialect_python_bindings( declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SCFLoopTransformOps.td + SOURCES + dialects/_loop_transform_ops_ext.py + dialects/transform/loop.py + DIALECT_NAME transform + EXTENSION_NAME loop_transform) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/LinalgStructuredTransformOps.td SOURCES dialects/_structured_transform_ops_ext.py diff --git a/mlir/python/mlir/dialects/SCFLoopTransformOps.td b/mlir/python/mlir/dialects/SCFLoopTransformOps.td new file mode 100644 index 0000000..5ef07fc --- /dev/null +++ b/mlir/python/mlir/dialects/SCFLoopTransformOps.td @@ -0,0 +1,21 @@ +//===-- SCFLoopTransformOps.td -----------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the loop transform ops +// provided by the SCF (and other) dialects. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS +#define PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.td" + +#endif // PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py new file mode 100644 index 0000000..7452c42 --- /dev/null +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -0,0 +1,113 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ..dialects import pdl +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Union + + +def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]], + default_value: int = None): + if isinstance(arg, IntegerAttr): + return arg + + if arg is None: + assert default_value is not None, "must provide default value" + arg = default_value + + return IntegerAttr.get(IntegerType.get_signless(64), arg) + + +class GetParentForOp: + """Extension for GetParentForOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + num_loops: int = 1, + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + num_loops=_get_int64_attr(num_loops, default_value=1), + ip=ip, + loc=loc) + + +class LoopOutlineOp: + """Extension for LoopOutlineOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + func_name: Union[str, StringAttr], + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + func_name=(func_name if isinstance(func_name, StringAttr) else + StringAttr.get(func_name)), + ip=ip, + loc=loc) + + +class LoopPeelOp: + """Extension for LoopPeelOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + fail_if_already_divisible: Union[bool, BoolAttr] = False, + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + fail_if_already_divisible=(fail_if_already_divisible if isinstance( + fail_if_already_divisible, BoolAttr) else + BoolAttr.get(fail_if_already_divisible)), + ip=ip, + loc=loc) + + +class LoopPipelineOp: + """Extension for LoopPipelineOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + iteration_interval: Optional[Union[int, IntegerAttr]] = None, + read_latency: Optional[Union[int, IntegerAttr]] = None, + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + iteration_interval=_get_int64_attr(iteration_interval, default_value=1), + read_latency=_get_int64_attr(read_latency, default_value=10), + ip=ip, + loc=loc) + + +class LoopUnrollOp: + """Extension for LoopUnrollOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + factor: Union[int, IntegerAttr], + ip=None, + loc=None): + super().__init__( + _get_op_result_or_value(target), + factor=_get_int64_attr(factor), + ip=ip, + loc=loc) diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py new file mode 100644 index 0000000..86f7278 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/loop.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._loop_transform_ops_gen import * diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir new file mode 100644 index 0000000..deaf367 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -0,0 +1,264 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @get_parent_for_op +func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) { + // expected-remark @below {{first loop}} + scf.for %i = %arg0 to %arg1 step %arg2 { + // expected-remark @below {{second loop}} + scf.for %j = %arg0 to %arg1 step %arg2 { + // expected-remark @below {{third loop}} + scf.for %k = %arg0 to %arg1 step %arg2 { + arith.addi %i, %j : index + } + } + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addi : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addi"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addi in %arg1 + // CHECK: = transform.loop.get_parent_for + %1 = transform.loop.get_parent_for %0 + %2 = transform.loop.get_parent_for %0 { num_loops = 2 } + %3 = transform.loop.get_parent_for %0 { num_loops = 3 } + transform.test_print_remark_at_operand %1, "third loop" + transform.test_print_remark_at_operand %2, "second loop" + transform.test_print_remark_at_operand %3, "first loop" + } +} + +// ----- + +func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) { + // expected-note @below {{target op}} + arith.addi %arg0, %arg1 : index + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addi : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addi"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addi in %arg1 + // expected-error @below {{could not find an 'scf.for' parent}} + %1 = transform.loop.get_parent_for %0 + } +} + +// ----- + +// Outlined functions: +// +// CHECK: func @foo(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}) +// CHECK: scf.for +// CHECK: arith.addi +// +// CHECK: func @foo[[SUFFIX:.+]](%{{.+}}, %{{.+}}, %{{.+}}) +// CHECK: scf.for +// CHECK: arith.addi +// +// CHECK-LABEL @loop_outline_op +func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) { + // CHECK: scf.for + // CHECK-NOT: scf.for + // CHECK: scf.execute_region + // CHECK: func.call @foo + scf.for %i = %arg0 to %arg1 step %arg2 { + scf.for %j = %arg0 to %arg1 step %arg2 { + arith.addi %i, %j : index + } + } + // CHECK: scf.execute_region + // CHECK-NOT: scf.for + // CHECK: func.call @foo[[SUFFIX]] + scf.for %j = %arg0 to %arg1 step %arg2 { + arith.addi %j, %j : index + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addi : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addi"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addi in %arg1 + %1 = transform.loop.get_parent_for %0 + // CHECK: = transform.loop.outline %{{.*}} + transform.loop.outline %1 {func_name = "foo"} + } +} + +// ----- + +func.func private @cond() -> i1 +func.func private @body() + +func.func @loop_outline_op_multi_region() { + // expected-note @below {{target op}} + scf.while : () -> () { + %0 = func.call @cond() : () -> i1 + scf.condition(%0) + } do { + ^bb0: + func.call @body() : () -> () + scf.yield + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_while : benefit(1) { + %args = operands + %results = types + %op = operation "scf.while"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_while in %arg1 + // expected-error @below {{failed to outline}} + transform.loop.outline %0 {func_name = "foo"} + } +} + +// ----- + +// CHECK-LABEL: @loop_peel_op +func.func @loop_peel_op() { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[C42:.+]] = arith.constant 42 + // CHECK: %[[C5:.+]] = arith.constant 5 + // CHECK: %[[C40:.+]] = arith.constant 40 + // CHECK: scf.for %{{.+}} = %[[C0]] to %[[C40]] step %[[C5]] + // CHECK: arith.addi + // CHECK: scf.for %{{.+}} = %[[C40]] to %[[C42]] step %[[C5]] + // CHECK: arith.addi + %0 = arith.constant 0 : index + %1 = arith.constant 42 : index + %2 = arith.constant 5 : index + scf.for %i = %0 to %1 step %2 { + arith.addi %i, %i : index + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addi : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addi"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addi in %arg1 + %1 = transform.loop.get_parent_for %0 + transform.loop.peel %1 + } +} + +// ----- + +func.func @loop_pipeline_op(%A: memref, %result: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %cf = arith.constant 1.0 : f32 + // CHECK: memref.load %[[MEMREF:.+]][%{{.+}}] + // CHECK: memref.load %[[MEMREF]] + // CHECK: arith.addf + // CHECK: scf.for + // CHECK: memref.load + // CHECK: arith.addf + // CHECK: memref.store + // CHECK: arith.addf + // CHECK: memref.store + // CHECK: memref.store + // expected-remark @below {{transformed}} + scf.for %i0 = %c0 to %c4 step %c1 { + %A_elem = memref.load %A[%i0] : memref + %A1_elem = arith.addf %A_elem, %cf : f32 + memref.store %A1_elem, %result[%i0] : memref + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addf : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addf"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addf in %arg1 + %1 = transform.loop.get_parent_for %0 + %2 = transform.loop.pipeline %1 + // Verify that the returned handle is usable. + transform.test_print_remark_at_operand %2, "transformed" + } +} + +// ----- + +// CHECK-LABEL: @loop_unroll_op +func.func @loop_unroll_op() { + %c0 = arith.constant 0 : index + %c42 = arith.constant 42 : index + %c5 = arith.constant 5 : index + // CHECK: scf.for %[[I:.+]] = + scf.for %i = %c0 to %c42 step %c5 { + // CHECK-COUNT-4: arith.addi %[[I]] + arith.addi %i, %i : index + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addi : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addi"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addi in %arg1 + %1 = transform.loop.get_parent_for %0 + transform.loop.unroll %1 { factor = 4 } + } +} + diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py new file mode 100644 index 0000000..a324266 --- /dev/null +++ b/mlir/test/python/dialects/transform_loop_ext.py @@ -0,0 +1,71 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import transform +from mlir.dialects import pdl +from mlir.dialects.transform import loop + + +def run(f): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + print("\nTEST:", f.__name__) + f() + print(module) + return f + + +@run +def getParentLoop(): + sequence = transform.SequenceOp() + with InsertionPoint(sequence.body): + loop.GetParentForOp(sequence.bodyTarget, num_loops=2) + transform.YieldOp() + # CHECK-LABEL: TEST: getParentLoop + # CHECK: = transform.loop.get_parent_for % + # CHECK: num_loops = 2 + + +@run +def loopOutline(): + sequence = transform.SequenceOp() + with InsertionPoint(sequence.body): + loop.LoopOutlineOp(sequence.bodyTarget, func_name="foo") + transform.YieldOp() + # CHECK-LABEL: TEST: loopOutline + # CHECK: = transform.loop.outline % + # CHECK: func_name = "foo" + + +@run +def loopPeel(): + sequence = transform.SequenceOp() + with InsertionPoint(sequence.body): + loop.LoopPeelOp(sequence.bodyTarget) + transform.YieldOp() + # CHECK-LABEL: TEST: loopPeel + # CHECK: = transform.loop.peel % + + +@run +def loopPipeline(): + sequence = transform.SequenceOp() + with InsertionPoint(sequence.body): + loop.LoopPipelineOp(sequence.bodyTarget, iteration_interval=3) + transform.YieldOp() + # CHECK-LABEL: TEST: loopPipeline + # CHECK: = transform.loop.pipeline % + # CHECK-DAG: iteration_interval = 3 + # CHECK-DAG: read_latency = 10 + + +@run +def loopUnroll(): + sequence = transform.SequenceOp() + with InsertionPoint(sequence.body): + loop.LoopUnrollOp(sequence.bodyTarget, factor=42) + transform.YieldOp() + # CHECK-LABEL: TEST: loopUnroll + # CHECK: transform.loop.unroll % + # CHECK: factor = 42 diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 516b1ec..b511d48 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1868,6 +1868,7 @@ cc_library( hdrs = [ "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h", "include/mlir/Dialect/SCF/Passes.h", + "include/mlir/Dialect/SCF/Patterns.h", "include/mlir/Dialect/SCF/Transforms.h", ], includes = ["include"], @@ -1892,6 +1893,59 @@ cc_library( ], ) +td_library( + name = "SCFTransformOpsTdFiles", + srcs = [ + "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td", + ], + includes = ["include"], + deps = [ + ":PDLDialect", + ":TransformDialectTdFiles", + ], +) + +gentbl_cc_library( + name = "SCFTransformOpsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-decls"], + "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h.inc", + ), + ( + ["-gen-op-defs"], + "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td", + deps = [ + ":SCFTransformOpsTdFiles", + ], +) + +cc_library( + name = "SCFTransformOps", + srcs = glob(["lib/Dialect/SCF/TransformOps/*.cpp"]), + hdrs = glob(["include/mlir/Dialect/SCF/TransformOps/*.h"]), + includes = ["include"], + deps = [ + ":Affine", + ":FuncDialect", + ":IR", + ":PDLDialect", + ":SCFDialect", + ":SCFTransformOpsIncGen", + ":SCFTransforms", + ":SCFUtils", + ":SideEffectInterfaces", + ":TransformDialect", + ":VectorOps", + "//llvm:Support", + ], +) + ##---------------------------------------------------------------------------## # SparseTensor dialect. ##---------------------------------------------------------------------------## @@ -2601,6 +2655,7 @@ cc_library( ], exclude = [ "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h", + "include/mlir/Dialect/SCF/Patterns.h", "include/mlir/Dialect/SCF/Transforms.h", ], ), @@ -6299,6 +6354,7 @@ cc_library( ":SCFPassIncGen", ":SCFToGPUPass", ":SCFToStandard", + ":SCFTransformOps", ":SCFTransforms", ":SDBM", ":SPIRVDialect", diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel index c94bc5d..c013492 100644 --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -878,11 +878,33 @@ gentbl_filegroup( ], ) +gentbl_filegroup( + name = "LoopTransformOpsPyGen", + tbl_outs = [ + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=transform", + "-dialect-extension=loop_transform", + ], + "mlir/dialects/_loop_transform_ops_gen.py", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "mlir/dialects/SCFLoopTransformOps.td", + deps = [ + ":TransformOpsPyTdFiles", + "//mlir:SCFTransformOpsTdFiles", + ], +) + filegroup( name = "TransformOpsPyFiles", srcs = [ + "mlir/dialects/_loop_transform_ops_ext.py", "mlir/dialects/_structured_transform_ops_ext.py", "mlir/dialects/_transform_ops_ext.py", + ":LoopTransformOpsPyGen", ":StructuredTransformOpsPyGen", ":TransformOpsPyGen", ], -- 2.7.4