#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Value.h"
-#include "llvm/ADT/None.h"
-#include "llvm/ADT/Optional.h"
using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
- SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
/// Check if given mapping attributes are one of the desired attributes
static DiagnosedSilenceableFailure
checkAttributeType(ArrayRef<DeviceMappingAttrInterface> threadMappingAttributes,
/// Alter kernel configuration of the given kernel.
static DiagnosedSilenceableFailure
-alterGpuLaunch(SimpleRewriter &rewriter, LaunchOp gpuLaunch,
+alterGpuLaunch(TrivialPatternRewriter &rewriter, LaunchOp gpuLaunch,
TransformOpInterface transformOp,
Optional<int64_t> gridDimX = llvm::None,
Optional<int64_t> gridDimY = llvm::None,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
auto transformOp = cast<TransformOpInterface>(getOperation());
if (!getGenerateGpuLaunch() && !gpuLaunch) {
}
MLIRContext *ctx = getContext();
- SimpleRewriter rewriter(ctx);
+ TrivialPatternRewriter rewriter(ctx);
rewriter.setInsertionPoint(target);
SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
return result;
}
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
- SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
/// function that returns the "main" result or failure. Returns failure if the
// Apply the pattern directly to the op.
PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
- SimpleRewriter rewriter(operation->getContext());
+ TrivialPatternRewriter rewriter(operation->getContext());
rewriter.setInsertionPoint(operation);
auto result = pattern.returningMatchAndRewrite(op, rewriter);
if (failed(result))
if (!tilingInterfaceOp)
return transformOp->emitError("only TilingInterface ops are supported");
- SimpleRewriter rewriter(target->getContext());
+ TrivialPatternRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
applyFn(tilingInterfaceOp);
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, tilingInterfaceOp, tileAndFuseOptions);
});
results.push_back(target);
return DiagnosedSilenceableFailure(success());
}
- SimpleRewriter rewriter(target->getContext());
+ TrivialPatternRewriter rewriter(target->getContext());
FailureOr<GenericOp> res =
interchangeGenericOp(rewriter, target, interchangeVector);
if (failed(res))
if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
- SimpleRewriter rewriter(target->getContext());
+ TrivialPatternRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
if (failed(res))
return tileSizes;
});
SmallVector<int64_t> emptyTileSizes;
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
TransformState &state) {
// Collect the dynamic split points if provided.
ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
SmallVector<OpFoldResult> splitPoints;
splitPoints.reserve(payload.size());
if (getDynamicSplitPoint()) {
}
LogicalResult SplitOp::verify() {
- if ((static_cast<int64_t>(getStaticSplitPoint()) !=
- ShapedType::kDynamic) ^
+ if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
(getDynamicSplitPoint() == nullptr)) {
return emitOpError() << "expects either a dynamic or a static split "
"point to be provided";
unsigned(getInsertSplitDimension()),
bool(getInnerParallel())};
};
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<SplitReductionResult> splitResult =
(getUseScalingAlgorithm())
DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
SmallVector<OpFoldResult> sizes;
transform::TileReductionUsingForeachThreadOp::applyToOne(
linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
SmallVector<int64_t> numThreads = extractFromI64ArrayAttr(getNumThreads());
SmallVector<OpFoldResult> numThreadResults;
}
tilingOptions.setInterchange(getInterchange());
- SimpleRewriter rewriter(linalgOp.getContext());
+ TrivialPatternRewriter rewriter(linalgOp.getContext());
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
tilingOptions);
}
tilingOptions.setInterchange(getInterchange());
- SimpleRewriter rewriter(tilingInterfaceOp.getContext());
+ TrivialPatternRewriter rewriter(tilingInterfaceOp.getContext());
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
if (failed(tilingResult))
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
- SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// GetParentForOp
//===----------------------------------------------------------------------===//
for (Operation *target : state.getPayloadOps(getTarget())) {
Location location = target->getLoc();
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
if (!exec) {
DiagnosedSilenceableFailure diag = emitSilenceableError()
getReadLatency());
};
scf::ForLoopPipeliningPattern pattern(options, target->getContext());
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<scf::ForOp> patternResult =
pattern.returningMatchAndRewrite(target, rewriter);
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
//===----------------------------------------------------------------------===//
namespace {
-/// A simple pattern rewriter that can be constructed from a context. This is
-/// necessary to apply patterns to a specific op locally.
-class TrivialPatternRewriter : public PatternRewriter {
-public:
- explicit TrivialPatternRewriter(MLIRContext *context)
- : PatternRewriter(context) {}
-};
-
/// A TransformState extension that keeps track of compiled PDL pattern sets.
/// This is intended to be used along the WithPDLPatterns op. The extension
/// can be constructed given an operation that has a SymbolTable trait and
}
PatternApplicator applicator(it->second);
- TrivialPatternRewriter rewriter(root->getContext());
+ transform::TrivialPatternRewriter rewriter(root->getContext());
applicator.applyDefaultCostModel();
root->walk([&](Operation *op) {
if (succeeded(applicator.matchAndRewrite(op, rewriter)))