add_dependencies(mlir-headers MLIRSCFPassIncGen)
add_mlir_doc(Passes SCFPasses ./ -gen-pass-doc)
+
+add_subdirectory(TransformOps)
--- /dev/null
+//===- Patterns.h - SCF dialect rewrite patterns ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_PATTERNS_H
+#define MLIR_DIALECT_SCF_PATTERNS_H
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace scf {
+/// Generate a pipelined version of the scf.for loop based on the schedule given
+/// as option. This applies the mechanical transformation of changing the loop
+/// and generating the prologue/epilogue for the pipelining and doesn't make any
+/// decision regarding the schedule.
+/// Based on the options the loop is split into several stages.
+/// The transformation assumes that the scheduling given by user is valid.
+/// For example if we break a loop into 3 stages named S0, S1, S2 we would
+/// generate the following code with the number in parenthesis as the iteration
+/// index:
+/// S0(0) // Prologue
+/// S0(1) S1(0) // Prologue
+/// scf.for %I = %C0 to %N - 2 {
+/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel
+/// }
+/// S1(N) S2(N-1) // Epilogue
+/// S2(N) // Epilogue
+class ForLoopPipeliningPattern : public OpRewritePattern<ForOp> {
+public:
+ /// Builds the pattern for `context`; `options` is copied into the pattern.
+ ForLoopPipeliningPattern(const PipeliningOption &options,
+ MLIRContext *context)
+ : OpRewritePattern<ForOp>(context), options(options) {}
+ /// Pattern entry point. Forwards to `returningMatchAndRewrite` and keeps
+ /// only its success/failure status.
+ LogicalResult matchAndRewrite(ForOp forOp,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(forOp, rewriter);
+ }
+
+ /// Performs the rewrite of `forOp` using `rewriter` and, unlike
+ /// `matchAndRewrite`, hands a `ForOp` back to the caller on success.
+ /// Defined out of line.
+ FailureOr<ForOp> returningMatchAndRewrite(ForOp forOp,
+ PatternRewriter &rewriter) const;
+
+protected:
+ /// Copy of the pipelining options supplied at construction time.
+ PipeliningOption options;
+};
+
+} // namespace scf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SCF_PATTERNS_H
--- /dev/null
+# Generate the op declarations (.h.inc) and definitions (.cpp.inc) from
+# SCFTransformOps.td and expose them through a public tablegen target.
+set(LLVM_TARGET_DEFINITIONS SCFTransformOps.td)
+mlir_tablegen(SCFTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(SCFTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRSCFTransformOpsIncGen)
--- /dev/null
+//===- SCFTransformOps.h - SCF transformation ops ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
+#define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+namespace func {
+class FuncOp;
+} // namespace func
+namespace scf {
+class ForOp;
+} // namespace scf
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace scf {
+/// Registers the SCF extension of the Transform dialect with `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace scf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
--- /dev/null
+//===- SCFTransformOps.td - SCF (loop) transformation ops --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SCF_TRANSFORM_OPS
+#define SCF_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformEffects.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+// Navigation-style transform op: for every payload op associated with
+// `$target`, looks up its `num_loops`-th enclosing `scf.for` and returns a
+// handle to the de-duplicated list of those loops.
+def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
+ [NavigationTransformOpTrait, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let summary = "Gets a handle to the parent 'for' loop of the given operation";
+ let description = [{
+ Produces a handle to the n-th (default 1) parent `scf.for` loop for each
+ Payload IR operation associated with the operand. Fails if such a loop
+ cannot be found. The list of operations associated with the handle contains
+ parent operations in the same order as the list associated with the operand,
+ except for operations that are parents to more than one input which are only
+ present once.
+ }];
+
+ // `num_loops` selects how many `scf.for` levels to walk up; it must be
+ // strictly positive and defaults to 1 (the immediate parent loop).
+ let arguments =
+ (ins PDL_Operation:$target,
+ DefaultValuedAttr<Confined<I64Attr, [IntPositive]>,
+ "1">:$num_loops);
+ let results = (outs PDL_Operation:$parent);
+
+ let assemblyFormat = "$target attr-dict";
+}
+
+// Functional-style transform op that outlines each payload op associated
+// with `$target` into its own function and returns a handle to the outlined
+// functions.
+def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let summary = "Outlines a loop into a named function";
+ let description = [{
+ Moves the loop into a separate function with the specified name and
+ replaces the loop in the Payload IR with a call to that function. Takes
+ care of forwarding values that are used in the loop as function arguments.
+ If the operand is associated with more than one loop, each loop will be
+ outlined into a separate function. The provided name is used as a _base_
+ for forming actual function names following SymbolTable auto-renaming
+ scheme to avoid duplicate symbols. Expects that all ops in the Payload IR
+ have a SymbolTable ancestor (typically true because of the top-level
+ module). Returns the handle to the list of outlined functions in the same
+ order as the operand handle.
+ }];
+
+ // `func_name` is mandatory and only serves as the base name of the
+ // generated functions (see the description above).
+ let arguments = (ins PDL_Operation:$target,
+ StrAttr:$func_name);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+}
+
+// Functional-style transform op applied to each `scf.for` associated with
+// `$target` individually (TransformEachOpTrait) through `applyToOne`.
+def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait]> {
+ let summary = "Peels the last iteration of the loop";
+ let description = [{
+ Updates the given loop so that its step evenly divides its range and puts
+ the remaining iteration into a separate loop or a conditional. Note that
+ even though the Payload IR modification may be performed in-place, this
+ operation consumes the operand handle and produces a new one. Applies to
+ each loop associated with the operand handle individually. The results
+ follow the same order as the operand.
+
+ Note: If it can be proven statically that the step already evenly divides
+ the range, this op is a no-op. In the absence of sufficient static
+ information, this op may peel a loop, even if the step always divides the
+ range evenly at runtime.
+ }];
+
+ // `fail_if_already_divisible` (default: false): when peeling does not
+ // succeed, the op reports a failure if this flag is set and otherwise
+ // returns the original loop unchanged (see LoopPeelOp::applyToOne).
+ let arguments =
+ (ins PDL_Operation:$target,
+ DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+
+ // Per-loop hook; returns the loop to associate with the result handle.
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+ }];
+}
+
+// Functional-style transform op that software-pipelines each `scf.for`
+// associated with `$target` individually (TransformEachOpTrait) through
+// `applyToOne`.
+def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Applies software pipelining to the loop";
+  let description = [{
+    Transforms the given loops one by one to achieve software pipelining for
+    each of them. That is, performs some amount of reads from memory before the
+    loop rather than inside the loop, the same amount of writes into memory
+    after the loop, and updates each iteration to read the data for a following
+    iteration rather than the current one. The amount is specified by the
+    attributes. The values read and about to be stored are transferred as loop
+    iteration arguments. Currently supports memref and vector transfer
+    operations as memory reads/writes.
+  }];
+
+  // `iteration_interval` is used as a divisor and modulus when the schedule
+  // is computed, so it is constrained to be strictly positive; a value of 0
+  // would otherwise lead to a division by zero. `read_latency` is the latency
+  // assigned to vector transfer reads when building the schedule.
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<Confined<I64Attr, [IntPositive]>,
+                                     "1">:$iteration_interval,
+                   DefaultValuedAttr<I64Attr, "10">:$read_latency);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+
+  // Per-loop hook; returns the loop to associate with the result handle.
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+  }];
+}
+
+// Transform op applied to each `scf.for` associated with `$target`
+// individually (TransformEachOpTrait) through `applyToOne`. It has no
+// results: the description below explains why no new handle is produced.
+def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait]> {
+ let summary = "Unrolls the given loop with the given unroll factor";
+ let description = [{
+ Unrolls each loop associated with the given handle to have up to the given
+ number of loop body copies per iteration. If the unroll factor is larger
+ than the loop trip count, the latter is used as the unroll factor instead.
+ Does not produce a new handle as the operation may result in the loop being
+ removed after a full unrolling.
+ }];
+
+ // `factor` is mandatory and must be strictly positive.
+ let arguments = (ins PDL_Operation:$target,
+ Confined<I64Attr, [IntPositive]>:$factor);
+
+ let assemblyFormat = "$target attr-dict";
+
+ // Per-loop hook; only reports success or failure since nothing is returned.
+ let extraClassDeclaration = [{
+ ::mlir::LogicalResult applyToOne(::mlir::scf::ForOp loop);
+ }];
+}
+
+#endif // SCF_TRANSFORM_OPS
// TODO: add option to decide if the prologue should be peeled.
};
-/// Populate patterns for SCF software pipelining transformation.
-/// This transformation generates the pipelined loop and doesn't do any
-/// assumptions on the schedule dictated by the option structure.
-/// Software pipelining is usually done in two part. The first part of
-/// pipelining is to schedule the loop and assign a stage and cycle to each
-/// operations. This is highly dependent on the target and is implemented as an
-/// heuristic based on operation latencies, and other hardware characteristics.
-/// The second part is to take the schedule and generate the pipelined loop as
-/// well as the prologue and epilogue. It is independent of the target.
-/// This pattern only implement the second part.
-/// For example if we break a loop into 3 stages named S0, S1, S2 we would
-/// generate the following code with the number in parenthesis the iteration
-/// index:
-/// S0(0) // Prologue
-/// S0(1) S1(0) // Prologue
-/// scf.for %I = %C0 to %N - 2 {
-/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel
-/// }
-/// S1(N) S2(N-1) // Epilogue
-/// S2(N) // Epilogue
+/// Populate patterns for SCF software pipelining transformation. See the
+/// ForLoopPipeliningPattern for the transformation details.
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
const PipeliningOption &options);
class Value;
namespace func {
+class CallOp;
class FuncOp;
} // namespace func
/// `outlinedFuncBody` to alloc simple canonicalizations.
/// Creates a new FuncOp and thus cannot be used in a FuncOp pass.
/// The client is responsible for providing a unique `funcName` that will not
-/// collide with another FuncOp name.
+/// collide with another FuncOp name. If `callOp` is provided, it will be set
+/// to point to the operation that calls the outlined function.
// TODO: support more than single-block regions.
// TODO: more flexible constant handling.
-FailureOr<func::FuncOp> outlineSingleBlockRegion(RewriterBase &rewriter,
- Location loc, Region ®ion,
- StringRef funcName);
+FailureOr<func::FuncOp>
+outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region,
+                         StringRef funcName, func::CallOp *callOp = nullptr);
/// Outline the then and/or else regions of `ifOp` as follows:
/// - if `thenFn` is not null, `thenFnName` must be specified and the `then`
StringRef getName() override { return "transform.payload_ir"; }
};
-/// Trait implementing the MemoryEffectOpInterface for single-operand
+/// Trait implementing the MemoryEffectOpInterface for single-operand zero- or
/// single-result operations that "consume" their operand and produce a new
/// result.
template <typename OpTy>
effects.emplace_back(MemoryEffects::Free::get(),
this->getOperation()->getOperand(0),
TransformMappingResource::get());
+ if (this->getOperation()->getNumResults() == 1) {
+ effects.emplace_back(MemoryEffects::Allocate::get(),
+ this->getOperation()->getResult(0),
+ TransformMappingResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ this->getOperation()->getResult(0),
+ TransformMappingResource::get());
+ }
+ effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+ }
+
+ /// Checks that the op matches the expectations of this trait. The operand
+ /// and result counts are enforced at compile time; the presence of
+ /// MemoryEffectOpInterface is checked when the op is verified.
+ static LogicalResult verifyTrait(Operation *op) {
+   static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
+                 "expected single-operand op");
+   static_assert(OpTy::template hasTrait<OpTrait::ZeroResults>() ||
+                     OpTy::template hasTrait<OpTrait::OneResult>(),
+                 "expected zero- or single-result op");
+   if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
+     // Return the diagnostic so that verification actually fails; emitting
+     // the error and then reporting success would let the invalid op pass.
+     return op->emitError()
+            << "FunctionalStyleTransformOpTrait should only be attached to ops "
+               "that implement MemoryEffectOpInterface";
+   }
+   return success();
+ }
+};
+
+/// Trait implementing the MemoryEffectOpInterface for single-operand
+/// single-result operations that use their operand without consuming and
+/// without modifying the Payload IR to produce a new handle.
+template <typename OpTy>
+class NavigationTransformOpTrait
+ : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {
+public:
+ /// This op produces handles to the Payload IR without consuming the original
+ /// handles and without modifying the IR itself.
+ void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(),
+ this->getOperation()->getOperand(0),
+ TransformMappingResource::get());
effects.emplace_back(MemoryEffects::Allocate::get(),
this->getOperation()->getResult(0),
TransformMappingResource::get());
this->getOperation()->getResult(0),
TransformMappingResource::get());
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
}
- /// Checks that the op matches the expectations of this trait.
+ /// Checks that the op matches the expectation of this trait.
static LogicalResult verifyTrait(Operation *op) {
static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
"expected single-operand op");
static_assert(OpTy::template hasTrait<OpTrait::OneResult>(),
"expected single-result op");
if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
- op->emitError()
- << "FunctionalStyleTransformOpTrait should only be attached to ops "
- "that implement MemoryEffectOpInterface";
+ op->emitError() << "NavigationTransformOpTrait should only be attached "
+ "to ops that implement MemoryEffectOpInterface";
}
return success();
}
"::mlir::transform::TransformState &":$state
)>,
];
+
+ let extraSharedClassDeclaration = [{
+ /// Emits a generic transform error for the current transform operation
+ /// targeting the given Payload IR operation and returns failure. Should
+ /// be only used as a last resort when the transformation itself provides
+ /// no further indication as to the reason of the failure.
+ ::mlir::LogicalResult reportUnknownTransformError(
+ ::mlir::Operation *target) {
+ ::mlir::InFlightDiagnostic diag = $_op->emitError() << "failed to apply";
+ diag.attachNote(target->getLoc()) << "attempted to apply to this op";
+ return diag;
+ }
+ }];
}
def FunctionalStyleTransformOpTrait
let cppNamespace = "::mlir::transform";
}
+// ODS handle for the native C++ trait
+// `::mlir::transform::NavigationTransformOpTrait`.
+def NavigationTransformOpTrait : NativeOpTrait<"NavigationTransformOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+}
+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent",
[DeclareOpInterfaceMethods<TransformOpInterface>,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
let summary = "Gets handles to the closest isolated-from-above parents";
let description = [{
The handles defined by this Transform op correspond to the closest isolated
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
// Register all dialect extensions.
linalg::registerTransformDialectExtension(registry);
+ scf::registerTransformDialectExtension(registry);
// Register all external models.
arith::registerBufferizableOpInterfaceExternalModels(registry);
if (succeeded(depthwise))
return depthwise;
- InFlightDiagnostic diag = emitError() << "failed to apply";
- diag.attachNote(target.getLoc()) << "attempted to apply to this op";
- return diag;
+ return reportUnknownTransformError(target);
}
//===----------------------------------------------------------------------===//
if (succeeded(generic))
return generic;
- InFlightDiagnostic diag = emitError() << "failed to apply";
- diag.attachNote(target.getLoc()) << "attempted to apply to this op";
- return diag;
+ return reportUnknownTransformError(target);
}
//===----------------------------------------------------------------------===//
if (getVectorizePadding())
linalg::populatePadOpVectorizationPatterns(patterns);
- if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) {
- InFlightDiagnostic diag = emitError() << "failed to apply";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
- }
+ if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
+ return reportUnknownTransformError(target);
return target;
}
MLIRSideEffectInterfaces
)
+add_subdirectory(TransformOps)
add_subdirectory(Transforms)
add_subdirectory(Utils)
--- /dev/null
+# Library holding the implementation of the SCF transform ops. It depends on
+# the tablegen target that generates the op declarations/definitions and links
+# publicly against the dialects and utilities listed below.
+add_mlir_dialect_library(MLIRSCFTransformOps
+ SCFTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF/TransformOps
+
+ DEPENDS
+ MLIRSCFTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAffine
+ MLIRFunc
+ MLIRIR
+ MLIRPDL
+ MLIRSCF
+ MLIRSCFTransforms
+ MLIRSCFUtils
+ MLIRTransformDialect
+ MLIRVector
+)
--- /dev/null
+//===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/Patterns.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+
+namespace {
+/// A simple pattern rewriter that implements no special logic. It only
+/// exposes a public constructor on top of PatternRewriter.
+class SimpleRewriter : public PatternRewriter {
+public:
+  /// `explicit` so that an `MLIRContext *` never converts to a rewriter
+  /// implicitly.
+  explicit SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// GetParentForOp
+//===----------------------------------------------------------------------===//
+
+/// For every payload op associated with the target handle, climbs
+/// `num_loops` levels of enclosing `scf.for` ops and records the loop reached.
+/// The loops are collected in a SetVector, so each one appears once, in
+/// first-seen order. Emits an error (with a note on the payload op) and fails
+/// as soon as one payload op does not have enough enclosing loops.
+LogicalResult
+transform::GetParentForOp::apply(transform::TransformResults &results,
+                                 transform::TransformState &state) {
+  SetVector<Operation *> enclosingLoops;
+  for (Operation *payloadOp : state.getPayloadOps(getTarget())) {
+    scf::ForOp ancestor;
+    Operation *cursor = payloadOp;
+    for (unsigned remaining = getNumLoops(); remaining > 0; --remaining) {
+      ancestor = cursor->getParentOfType<scf::ForOp>();
+      if (ancestor) {
+        // Keep climbing from the loop that was just found.
+        cursor = ancestor;
+        continue;
+      }
+      InFlightDiagnostic diag = emitError()
+                                << "could not find an '"
+                                << scf::ForOp::getOperationName() << "' parent";
+      diag.attachNote(payloadOp->getLoc()) << "target op";
+      return diag;
+    }
+    enclosingLoops.insert(ancestor);
+  }
+  results.set(getResult().cast<OpResult>(), enclosingLoops.getArrayRef());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopOutlineOp
+//===----------------------------------------------------------------------===//
+
+/// Encloses `op` in a freshly created `scf.execute_region` and returns that
+/// wrapper, or a null op when `op` does not have exactly one region. Every IR
+/// change goes through the rewriter `b` (instead of splicing `op` directly) so
+/// that the function stays compatible with the rewriting infrastructure. The
+/// insertion point of `b` is restored on return.
+static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
+                                                Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+
+  OpBuilder::InsertionGuard outerGuard(b);
+  b.setInsertionPoint(op);
+  Location loc = op->getLoc();
+  auto wrapper = b.create<scf::ExecuteRegionOp>(loc, op->getResultTypes());
+  {
+    OpBuilder::InsertionGuard innerGuard(b);
+    Block &wrapperBody = wrapper.getRegion().emplaceBlock();
+    b.setInsertionPointToStart(&wrapperBody);
+    // Clone only the "shell" of the op, then move the original region into it.
+    Operation *shell = b.cloneWithoutRegions(*op);
+    Region &shellRegion = shell->getRegions().front();
+    assert(shellRegion.empty() && "expected empty region");
+    b.inlineRegionBefore(op->getRegions().front(), shellRegion,
+                         shellRegion.end());
+    b.create<scf::YieldOp>(loc, shell->getResults());
+  }
+  b.replaceOp(op, wrapper.getResults());
+  return wrapper;
+}
+
+/// Outlines every payload op associated with the target handle into its own
+/// function: the op is first wrapped into an `scf.execute_region`, whose
+/// region is then outlined. When a symbol table ancestor exists, the new
+/// function is inserted into it and the call is updated to reference the
+/// inserted function. The result handle is associated with the outlined
+/// functions, in the order of the payload ops.
+LogicalResult
+transform::LoopOutlineOp::apply(transform::TransformResults &results,
+                                transform::TransformState &state) {
+  SmallVector<Operation *> transformed;
+  // One SymbolTable per symbol table op, created on first use.
+  DenseMap<Operation *, SymbolTable> symbolTables;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    // Everything needed from `target` is read up front: once wrapping
+    // succeeds, `target` has been replaced through the rewriter.
+    Location location = target->getLoc();
+    Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
+    SimpleRewriter rewriter(getContext());
+    scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
+    if (!exec) {
+      // Wrapping bails out before modifying the IR, so `target` is intact.
+      InFlightDiagnostic diag = emitError() << "failed to outline";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    func::CallOp call;
+    FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
+        rewriter, location, exec.getRegion(), getFuncName(), &call);
+
+    // `target` must not be dereferenced anymore at this point; attach the
+    // diagnostic to the wrapper op that took its place instead.
+    if (failed(outlined))
+      return reportUnknownTransformError(exec.getOperation());
+
+    if (symbolTableOp) {
+      SymbolTable &symbolTable =
+          symbolTables.try_emplace(symbolTableOp, symbolTableOp)
+              .first->getSecond();
+      symbolTable.insert(*outlined);
+      // Re-point the call after insertion, using the function's current
+      // symbol.
+      call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
+    }
+    transformed.push_back(*outlined);
+  }
+  results.set(getTransformed().cast<OpResult>(), transformed);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopPeelOp
+//===----------------------------------------------------------------------===//
+
+/// Peels `loop` and returns the loop produced by the peeling utility. When
+/// peeling does not succeed, either reports a failure (if
+/// `fail_if_already_divisible` is set) or returns `loop` unchanged.
+FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop) {
+  scf::ForOp peeled;
+  IRRewriter rewriter(loop->getContext());
+  if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, loop, peeled)))
+    return peeled;
+
+  // Peeling did not happen; whether this is an error is up to the attribute.
+  if (getFailIfAlreadyDivisible())
+    return reportUnknownTransformError(loop);
+  return loop;
+}
+
+//===----------------------------------------------------------------------===//
+// LoopPipelineOp
+//===----------------------------------------------------------------------===//
+
+/// Callback for PipeliningOption. Populates `schedule` with the mapping from an
+/// operation to its logical time position given the iteration interval and the
+/// read latency. The latter is only relevant for vector transfers.
+///
+/// Each op of the loop body (except the terminator) is given the earliest
+/// cycle at which all its operands are available. Ops are then emitted grouped
+/// by `cycle % iterationInterval`, in increasing order, and paired with the
+/// stage `cycle / iterationInterval`. `iterationInterval` must be non-zero.
+static void
+loopScheduling(scf::ForOp forOp,
+               std::vector<std::pair<Operation *, unsigned>> &schedule,
+               unsigned iterationInterval, unsigned readLatency) {
+  // The interval is used as a divisor and a modulus below.
+  assert(iterationInterval > 0 && "expected non-zero iteration interval");
+
+  auto getLatency = [&](Operation *op) -> unsigned {
+    if (isa<vector::TransferReadOp>(op))
+      return readLatency;
+    return 1;
+  };
+
+  DenseMap<Operation *, unsigned> opCycles;
+  // Ordered map: buckets are visited by increasing wrapped cycle.
+  std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
+  for (Operation &op : forOp.getBody()->getOperations()) {
+    if (isa<scf::YieldOp>(op))
+      continue;
+    unsigned earlyCycle = 0;
+    for (Value operand : op.getOperands()) {
+      Operation *def = operand.getDefiningOp();
+      if (!def)
+        continue;
+      // A defining op without a recorded cycle is looked up through
+      // operator[], i.e. it counts as scheduled at cycle 0.
+      earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
+    }
+    opCycles[&op] = earlyCycle;
+    wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
+  }
+  // Iterate by reference: copying an entry would copy its whole op vector.
+  for (const auto &bucket : wrappedSchedule) {
+    for (Operation *op : bucket.second) {
+      unsigned cycle = opCycles[op];
+      schedule.push_back(std::make_pair(op, cycle / iterationInterval));
+    }
+  }
+}
+
+/// Pipelines `loop` with ForLoopPipeliningPattern, using a schedule computed
+/// by `loopScheduling` from the `iteration_interval` and `read_latency`
+/// attributes. Returns the pattern's result on success and reports a generic
+/// transform error otherwise.
+FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) {
+  scf::PipeliningOption pipeliningOptions;
+  pipeliningOptions.getScheduleFn =
+      [this](scf::ForOp forOp,
+             std::vector<std::pair<Operation *, unsigned>> &schedule) {
+        loopScheduling(forOp, schedule, getIterationInterval(),
+                       getReadLatency());
+      };
+
+  scf::ForLoopPipeliningPattern pipeliner(pipeliningOptions,
+                                          loop->getContext());
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(loop);
+  FailureOr<scf::ForOp> pipelined =
+      pipeliner.returningMatchAndRewrite(loop, rewriter);
+  if (succeeded(pipelined))
+    return pipelined;
+  return reportUnknownTransformError(loop);
+}
+
+//===----------------------------------------------------------------------===//
+// LoopUnrollOp
+//===----------------------------------------------------------------------===//
+
+/// Unrolls `loop` by the `factor` attribute; reports a generic transform
+/// error when the unrolling utility fails.
+LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) {
+  LogicalResult unrolled = loopUnrollByFactor(loop, getFactor());
+  if (succeeded(unrolled))
+    return success();
+  return reportUnknownTransformError(loop);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension that contributes the SCF transform ops.
+class SCFTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ SCFTransformDialectExtension> {
+public:
+ /// Declares the dialects this extension depends on and registers every op
+ /// listed in the generated SCFTransformOps.cpp.inc.
+ SCFTransformDialectExtension() {
+ // NOTE(review): presumably needed because applying the transforms may
+ // create ops from these dialects (e.g. `func` ops when outlining) -- confirm.
+ declareDependentDialect<AffineDialect>();
+ declareDependentDialect<func::FuncDialect>();
+ // GET_OP_LIST makes the generated file expand to the list of op classes,
+ // which is used here as the template argument list.
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
+
+/// Adds the SCF transform ops extension to `registry`.
+void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<SCFTransformDialectExtension>();
+}
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/SCF/Patterns.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
it->second[idx] = el;
}
-/// Generate a pipelined version of the scf.for loop based on the schedule given
-/// as option. This applies the mechanical transformation of changing the loop
-/// and generating the prologue/epilogue for the pipelining and doesn't make any
-/// decision regarding the schedule.
-/// Based on the option the loop is split into several stages.
-/// The transformation assumes that the scheduling given by user is valid.
-/// For example if we break a loop into 3 stages named S0, S1, S2 we would
-/// generate the following code with the number in parenthesis the iteration
-/// index:
-/// S0(0) // Prologue
-/// S0(1) S1(0) // Prologue
-/// scf.for %I = %C0 to %N - 2 {
-/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel
-/// }
-/// S1(N) S2(N-1) // Epilogue
-/// S2(N) // Epilogue
-struct ForLoopPipelining : public OpRewritePattern<ForOp> {
- ForLoopPipelining(const PipeliningOption &options, MLIRContext *context)
- : OpRewritePattern<ForOp>(context), options(options) {}
- LogicalResult matchAndRewrite(ForOp forOp,
- PatternRewriter &rewriter) const override {
+} // namespace
- LoopPipelinerInternal pipeliner;
- if (!pipeliner.initializeLoopInfo(forOp, options))
- return failure();
+FailureOr<ForOp> ForLoopPipeliningPattern::returningMatchAndRewrite(
+ ForOp forOp, PatternRewriter &rewriter) const {
- // 1. Emit prologue.
- pipeliner.emitPrologue(rewriter);
+ LoopPipelinerInternal pipeliner;
+ if (!pipeliner.initializeLoopInfo(forOp, options))
+ return failure();
- // 2. Track values used across stages. When a value cross stages it will
- // need to be passed as loop iteration arguments.
- // We first collect the values that are used in a different stage than where
- // they are defined.
- llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
- crossStageValues = pipeliner.analyzeCrossStageValues();
+ // 1. Emit prologue.
+ pipeliner.emitPrologue(rewriter);
- // Mapping between original loop values used cross stage and the block
- // arguments associated after pipelining. A Value may map to several
- // arguments if its liverange spans across more than 2 stages.
- llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
- // 3. Create the new kernel loop and return the block arguments mapping.
- ForOp newForOp =
- pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
- // Create the kernel block, order ops based on user choice and remap
- // operands.
- pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter);
+ // 2. Track values used across stages. When a value cross stages it will
+ // need to be passed as loop iteration arguments.
+ // We first collect the values that are used in a different stage than where
+ // they are defined.
+ llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
+ crossStageValues = pipeliner.analyzeCrossStageValues();
- llvm::SmallVector<Value> returnValues =
- newForOp.getResults().take_front(forOp->getNumResults());
- if (options.peelEpilogue) {
- // 4. Emit the epilogue after the new forOp.
- rewriter.setInsertionPointAfter(newForOp);
- returnValues = pipeliner.emitEpilogue(rewriter);
- }
- // 5. Erase the original loop and replace the uses with the epilogue output.
- if (forOp->getNumResults() > 0)
- rewriter.replaceOp(forOp, returnValues);
- else
- rewriter.eraseOp(forOp);
+ // Mapping between original loop values used cross stage and the block
+ // arguments associated after pipelining. A Value may map to several
+ // arguments if its liverange spans across more than 2 stages.
+ llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
+ // 3. Create the new kernel loop and return the block arguments mapping.
+ ForOp newForOp =
+ pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
+ // Create the kernel block, order ops based on user choice and remap
+ // operands.
+ pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter);
- return success();
+ llvm::SmallVector<Value> returnValues =
+ newForOp.getResults().take_front(forOp->getNumResults());
+ if (options.peelEpilogue) {
+ // 4. Emit the epilogue after the new forOp.
+ rewriter.setInsertionPointAfter(newForOp);
+ returnValues = pipeliner.emitEpilogue(rewriter);
}
+ // 5. Erase the original loop and replace the uses with the epilogue output.
+ if (forOp->getNumResults() > 0)
+ rewriter.replaceOp(forOp, returnValues);
+ else
+ rewriter.eraseOp(forOp);
-protected:
- PipeliningOption options;
-};
-
-} // namespace
+ return newForOp;
+}
void mlir::scf::populateSCFLoopPipeliningPatterns(
RewritePatternSet &patterns, const PipeliningOption &options) {
- patterns.add<ForLoopPipelining>(options, patterns.getContext());
+ patterns.add<ForLoopPipeliningPattern>(options, patterns.getContext());
}
/// Assumes the FuncOp result types is the type of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` at the start of
-/// `outlinedFuncBody` to alloc simple canonicalizations.
+/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
+/// provided, it will be set to point to the operation that calls the outlined
+/// function.
// TODO: support more than single-block regions.
// TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
Location loc,
Region &region,
- StringRef funcName) {
+ StringRef funcName,
+ func::CallOp *callOp) {
assert(!funcName.empty() && "funcName cannot be empty");
if (!region.hasOneBlock())
return failure();
SmallVector<Value> callValues;
llvm::append_range(callValues, newBlock->getArguments());
llvm::append_range(callValues, outlinedValues);
- Operation *call =
- rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
+ auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
+ if (callOp)
+ *callOp = call;
// `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
// Clone `originalTerminator` to take the callOp results then erase it from
return success();
}
-void transform::GetClosestIsolatedParentOp::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
- TransformMappingResource::get());
- effects.emplace_back(MemoryEffects::Allocate::get(), getParent(),
- TransformMappingResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), getParent(),
- TransformMappingResource::get());
- effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
-}
-
//===----------------------------------------------------------------------===//
// PDLMatchOp
//===----------------------------------------------------------------------===//
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/SCFLoopTransformOps.td
+ SOURCES
+ dialects/_loop_transform_ops_ext.py
+ dialects/transform/loop.py
+ DIALECT_NAME transform
+ EXTENSION_NAME loop_transform)
+
+declare_mlir_dialect_extension_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgStructuredTransformOps.td
SOURCES
dialects/_structured_transform_ops_ext.py
--- /dev/null
+//===-- SCFLoopTransformOps.td -----------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the Python bindings generator for the loop transform ops
+// provided by the SCF (and other) dialects.
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS
+#define PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.td"
+
+#endif // PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS
--- /dev/null
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+ from ..ir import *
+ from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+ from ..dialects import pdl
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Union
+
+
+def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]],
+ default_value: int = None):
+ if isinstance(arg, IntegerAttr):
+ return arg
+
+ if arg is None:
+ assert default_value is not None, "must provide default value"
+ arg = default_value
+
+ return IntegerAttr.get(IntegerType.get_signless(64), arg)
+
+
+class GetParentForOp:
+ """Extension for GetParentForOp."""
+
+ def __init__(self,
+ target: Union[Operation, Value],
+ *,
+ num_loops: int = 1,
+ ip=None,
+ loc=None):
+ super().__init__(
+ pdl.OperationType.get(),
+ _get_op_result_or_value(target),
+ num_loops=_get_int64_attr(num_loops, default_value=1),
+ ip=ip,
+ loc=loc)
+
+
+class LoopOutlineOp:
+ """Extension for LoopOutlineOp."""
+
+ def __init__(self,
+ target: Union[Operation, Value],
+ *,
+ func_name: Union[str, StringAttr],
+ ip=None,
+ loc=None):
+ super().__init__(
+ pdl.OperationType.get(),
+ _get_op_result_or_value(target),
+ func_name=(func_name if isinstance(func_name, StringAttr) else
+ StringAttr.get(func_name)),
+ ip=ip,
+ loc=loc)
+
+
+class LoopPeelOp:
+ """Extension for LoopPeelOp."""
+
+ def __init__(self,
+ target: Union[Operation, Value],
+ *,
+ fail_if_already_divisible: Union[bool, BoolAttr] = False,
+ ip=None,
+ loc=None):
+ super().__init__(
+ pdl.OperationType.get(),
+ _get_op_result_or_value(target),
+ fail_if_already_divisible=(fail_if_already_divisible if isinstance(
+ fail_if_already_divisible, BoolAttr) else
+ BoolAttr.get(fail_if_already_divisible)),
+ ip=ip,
+ loc=loc)
+
+
+class LoopPipelineOp:
+ """Extension for LoopPipelineOp."""
+
+ def __init__(self,
+ target: Union[Operation, Value],
+ *,
+ iteration_interval: Optional[Union[int, IntegerAttr]] = None,
+ read_latency: Optional[Union[int, IntegerAttr]] = None,
+ ip=None,
+ loc=None):
+ super().__init__(
+ pdl.OperationType.get(),
+ _get_op_result_or_value(target),
+ iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
+ read_latency=_get_int64_attr(read_latency, default_value=10),
+ ip=ip,
+ loc=loc)
+
+
+class LoopUnrollOp:
+ """Extension for LoopUnrollOp."""
+
+ def __init__(self,
+ target: Union[Operation, Value],
+ *,
+ factor: Union[int, IntegerAttr],
+ ip=None,
+ loc=None):
+ super().__init__(
+ _get_op_result_or_value(target),
+ factor=_get_int64_attr(factor),
+ ip=ip,
+ loc=loc)
--- /dev/null
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._loop_transform_ops_gen import *
--- /dev/null
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @get_parent_for_op
+func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
+ // expected-remark @below {{first loop}}
+ scf.for %i = %arg0 to %arg1 step %arg2 {
+ // expected-remark @below {{second loop}}
+ scf.for %j = %arg0 to %arg1 step %arg2 {
+ // expected-remark @below {{third loop}}
+ scf.for %k = %arg0 to %arg1 step %arg2 {
+ arith.addi %i, %j : index
+ }
+ }
+ }
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_addi : benefit(1) {
+ %args = operands
+ %results = types
+ %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ rewrite %op with "transform.dialect"
+ }
+
+ sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_addi in %arg1
+ // CHECK: = transform.loop.get_parent_for
+ %1 = transform.loop.get_parent_for %0
+ %2 = transform.loop.get_parent_for %0 { num_loops = 2 }
+ %3 = transform.loop.get_parent_for %0 { num_loops = 3 }
+ transform.test_print_remark_at_operand %1, "third loop"
+ transform.test_print_remark_at_operand %2, "second loop"
+ transform.test_print_remark_at_operand %3, "first loop"
+ }
+}
+
+// -----
+
+func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
+ // expected-note @below {{target op}}
+ arith.addi %arg0, %arg1 : index
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_addi : benefit(1) {
+ %args = operands
+ %results = types
+ %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ rewrite %op with "transform.dialect"
+ }
+
+ sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_addi in %arg1
+ // expected-error @below {{could not find an 'scf.for' parent}}
+ %1 = transform.loop.get_parent_for %0
+ }
+}
+
+// -----
+
+// Outlined functions:
+//
+// CHECK: func @foo(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}})
+// CHECK: scf.for
+// CHECK: arith.addi
+//
+// CHECK: func @foo[[SUFFIX:.+]](%{{.+}}, %{{.+}}, %{{.+}})
+// CHECK: scf.for
+// CHECK: arith.addi
+//
+// CHECK-LABEL: @loop_outline_op
+func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) {
+ // CHECK: scf.for
+ // CHECK-NOT: scf.for
+ // CHECK: scf.execute_region
+ // CHECK: func.call @foo
+ scf.for %i = %arg0 to %arg1 step %arg2 {
+ scf.for %j = %arg0 to %arg1 step %arg2 {
+ arith.addi %i, %j : index
+ }
+ }
+ // CHECK: scf.execute_region
+ // CHECK-NOT: scf.for
+ // CHECK: func.call @foo[[SUFFIX]]
+ scf.for %j = %arg0 to %arg1 step %arg2 {
+ arith.addi %j, %j : index
+ }
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_addi : benefit(1) {
+ %args = operands
+ %results = types
+ %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ rewrite %op with "transform.dialect"
+ }
+
+ sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_addi in %arg1
+ %1 = transform.loop.get_parent_for %0
+ // CHECK: = transform.loop.outline %{{.*}}
+ transform.loop.outline %1 {func_name = "foo"}
+ }
+}
+
+// -----
+
+func.func private @cond() -> i1
+func.func private @body()
+
+func.func @loop_outline_op_multi_region() {
+ // expected-note @below {{target op}}
+ scf.while : () -> () {
+ %0 = func.call @cond() : () -> i1
+ scf.condition(%0)
+ } do {
+ ^bb0:
+ func.call @body() : () -> ()
+ scf.yield
+ }
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_while : benefit(1) {
+ %args = operands
+ %results = types
+ %op = operation "scf.while"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ rewrite %op with "transform.dialect"
+ }
+
+ sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_while in %arg1
+ // expected-error @below {{failed to outline}}
+ transform.loop.outline %0 {func_name = "foo"}
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_peel_op
+func.func @loop_peel_op() {
+ // CHECK: %[[C0:.+]] = arith.constant 0
+ // CHECK: %[[C42:.+]] = arith.constant 42
+ // CHECK: %[[C5:.+]] = arith.constant 5
+ // CHECK: %[[C40:.+]] = arith.constant 40
+ // CHECK: scf.for %{{.+}} = %[[C0]] to %[[C40]] step %[[C5]]
+ // CHECK: arith.addi
+ // CHECK: scf.for %{{.+}} = %[[C40]] to %[[C42]] step %[[C5]]
+ // CHECK: arith.addi
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 42 : index
+ %2 = arith.constant 5 : index
+ scf.for %i = %0 to %1 step %2 {
+ arith.addi %i, %i : index
+ }
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_addi : benefit(1) {
+ %args = operands
+ %results = types
+ %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ rewrite %op with "transform.dialect"
+ }
+
+ sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_addi in %arg1
+ %1 = transform.loop.get_parent_for %0
+ transform.loop.peel %1
+ }
+}
+
+// -----
+
+func.func @loop_pipeline_op(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cf = arith.constant 1.0 : f32
+ // CHECK: memref.load %[[MEMREF:.+]][%{{.+}}]
+ // CHECK: memref.load %[[MEMREF]]
+ // CHECK: arith.addf
+ // CHECK: scf.for
+ // CHECK: memref.load
+ // CHECK: arith.addf
+ // CHECK: memref.store
+ // CHECK: arith.addf
+ // CHECK: memref.store
+ // CHECK: memref.store
+ // expected-remark @below {{transformed}}
+ scf.for %i0 = %c0 to %c4 step %c1 {
+ %A_elem = memref.load %A[%i0] : memref<?xf32>
+ %A1_elem = arith.addf %A_elem, %cf : f32
+ memref.store %A1_elem, %result[%i0] : memref<?xf32>
+ }
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_addf : benefit(1) {
+ %args = operands
+ %results = types
+ %op = operation "arith.addf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ rewrite %op with "transform.dialect"
+ }
+
+ sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_addf in %arg1
+ %1 = transform.loop.get_parent_for %0
+ %2 = transform.loop.pipeline %1
+ // Verify that the returned handle is usable.
+ transform.test_print_remark_at_operand %2, "transformed"
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_op
+func.func @loop_unroll_op() {
+ %c0 = arith.constant 0 : index
+ %c42 = arith.constant 42 : index
+ %c5 = arith.constant 5 : index
+ // CHECK: scf.for %[[I:.+]] =
+ scf.for %i = %c0 to %c42 step %c5 {
+ // CHECK-COUNT-4: arith.addi %[[I]]
+ arith.addi %i, %i : index
+ }
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_addi : benefit(1) {
+ %args = operands
+ %results = types
+ %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ rewrite %op with "transform.dialect"
+ }
+
+ sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_addi in %arg1
+ %1 = transform.loop.get_parent_for %0
+ transform.loop.unroll %1 { factor = 4 }
+ }
+}
+
--- /dev/null
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects import pdl
+from mlir.dialects.transform import loop
+
+
+def run(f):
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ print("\nTEST:", f.__name__)
+ f()
+ print(module)
+ return f
+
+
+@run
+def getParentLoop():
+ sequence = transform.SequenceOp()
+ with InsertionPoint(sequence.body):
+ loop.GetParentForOp(sequence.bodyTarget, num_loops=2)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: getParentLoop
+ # CHECK: = transform.loop.get_parent_for %
+ # CHECK: num_loops = 2
+
+
+@run
+def loopOutline():
+ sequence = transform.SequenceOp()
+ with InsertionPoint(sequence.body):
+ loop.LoopOutlineOp(sequence.bodyTarget, func_name="foo")
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: loopOutline
+ # CHECK: = transform.loop.outline %
+ # CHECK: func_name = "foo"
+
+
+@run
+def loopPeel():
+ sequence = transform.SequenceOp()
+ with InsertionPoint(sequence.body):
+ loop.LoopPeelOp(sequence.bodyTarget)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: loopPeel
+ # CHECK: = transform.loop.peel %
+
+
+@run
+def loopPipeline():
+ sequence = transform.SequenceOp()
+ with InsertionPoint(sequence.body):
+ loop.LoopPipelineOp(sequence.bodyTarget, iteration_interval=3)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: loopPipeline
+ # CHECK: = transform.loop.pipeline %
+ # CHECK-DAG: iteration_interval = 3
+ # CHECK-DAG: read_latency = 10
+
+
+@run
+def loopUnroll():
+ sequence = transform.SequenceOp()
+ with InsertionPoint(sequence.body):
+ loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: loopUnroll
+ # CHECK: transform.loop.unroll %
+ # CHECK: factor = 42
hdrs = [
"include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
"include/mlir/Dialect/SCF/Passes.h",
+ "include/mlir/Dialect/SCF/Patterns.h",
"include/mlir/Dialect/SCF/Transforms.h",
],
includes = ["include"],
],
)
+td_library(
+ name = "SCFTransformOpsTdFiles",
+ srcs = [
+ "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td",
+ ],
+ includes = ["include"],
+ deps = [
+ ":PDLDialect",
+ ":TransformDialectTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "SCFTransformOpsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td",
+ deps = [
+ ":SCFTransformOpsTdFiles",
+ ],
+)
+
+cc_library(
+ name = "SCFTransformOps",
+ srcs = glob(["lib/Dialect/SCF/TransformOps/*.cpp"]),
+ hdrs = glob(["include/mlir/Dialect/SCF/TransformOps/*.h"]),
+ includes = ["include"],
+ deps = [
+ ":Affine",
+ ":FuncDialect",
+ ":IR",
+ ":PDLDialect",
+ ":SCFDialect",
+ ":SCFTransformOpsIncGen",
+ ":SCFTransforms",
+ ":SCFUtils",
+ ":SideEffectInterfaces",
+ ":TransformDialect",
+ ":VectorOps",
+ "//llvm:Support",
+ ],
+)
+
##---------------------------------------------------------------------------##
# SparseTensor dialect.
##---------------------------------------------------------------------------##
],
exclude = [
"include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
+ "include/mlir/Dialect/SCF/Patterns.h",
"include/mlir/Dialect/SCF/Transforms.h",
],
),
":SCFPassIncGen",
":SCFToGPUPass",
":SCFToStandard",
+ ":SCFTransformOps",
":SCFTransforms",
":SDBM",
":SPIRVDialect",
],
)
+gentbl_filegroup(
+ name = "LoopTransformOpsPyGen",
+ tbl_outs = [
+ (
+ [
+ "-gen-python-op-bindings",
+ "-bind-dialect=transform",
+ "-dialect-extension=loop_transform",
+ ],
+ "mlir/dialects/_loop_transform_ops_gen.py",
+ ),
+ ],
+ tblgen = "//mlir:mlir-tblgen",
+ td_file = "mlir/dialects/SCFLoopTransformOps.td",
+ deps = [
+ ":TransformOpsPyTdFiles",
+ "//mlir:SCFTransformOpsTdFiles",
+ ],
+)
+
filegroup(
name = "TransformOpsPyFiles",
srcs = [
+ "mlir/dialects/_loop_transform_ops_ext.py",
"mlir/dialects/_structured_transform_ops_ext.py",
"mlir/dialects/_transform_ops_ext.py",
+ ":LoopTransformOpsPyGen",
":StructuredTransformOpsPyGen",
":TransformOpsPyGen",
],