From a8850312c106d71f9e35fd902f9dcd3c4ac0a690 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 1 Dec 2022 03:48:23 -0800 Subject: [PATCH] [mlir][Transform][NFC] Use a single rewriter instead of duplicating it everywhere Differential Revision: https://reviews.llvm.org/D139094 --- .../mlir/Dialect/Transform/IR/TransformUtils.h | 29 +++++++++++++++++ .../Affine/TransformOps/AffineTransformOps.cpp | 9 +----- .../Dialect/GPU/TransformOps/GPUTransformOps.cpp | 19 +++--------- .../Linalg/TransformOps/LinalgTransformOps.cpp | 36 +++++++++------------- .../Dialect/SCF/TransformOps/SCFTransformOps.cpp | 13 ++------ mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 11 ++----- 6 files changed, 53 insertions(+), 64 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h b/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h new file mode 100644 index 0000000..512c915 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h @@ -0,0 +1,29 @@ +//===- TransformUtils.h - Transform Dialect Utils ----------------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace transform { + +/// A simple pattern rewriter that can be constructed from a context. This is +/// necessary to apply patterns to a specific op locally. +class TrivialPatternRewriter : public PatternRewriter { +public: + explicit TrivialPatternRewriter(MLIRContext *context) + : PatternRewriter(context) {} +}; + +} // namespace transform +} // namespace mlir + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp index 605c07f3..eafc6d9 100644 --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -11,17 +11,10 @@ #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformUtils.h" using namespace mlir; -namespace { -/// A simple pattern rewriter that implements no special logic. -class SimpleRewriter : public PatternRewriter { -public: - SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} -}; -} // namespace - //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index a3d0261..e035a26 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -16,24 +16,13 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Value.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" using namespace mlir; using namespace mlir::gpu; using namespace mlir::transform; -namespace { -/// A simple pattern rewriter that implements no special logic. -class SimpleRewriter : public PatternRewriter { -public: - SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} -}; -} // namespace - /// Check if given mapping attributes are one of the desired attributes static DiagnosedSilenceableFailure checkAttributeType(ArrayRef threadMappingAttributes, @@ -135,7 +124,7 @@ createGpuLaunch(RewriterBase &rewriter, Location loc, /// Alter kernel configuration of the given kernel. static DiagnosedSilenceableFailure -alterGpuLaunch(SimpleRewriter &rewriter, LaunchOp gpuLaunch, +alterGpuLaunch(TrivialPatternRewriter &rewriter, LaunchOp gpuLaunch, TransformOpInterface transformOp, Optional gridDimX = llvm::None, Optional gridDimY = llvm::None, @@ -305,7 +294,7 @@ transform::MapForeachToBlocks::applyToOne(Operation *target, SmallVectorImpl &results, transform::TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); auto transformOp = cast(getOperation()); if (!getGenerateGpuLaunch() && !gpuLaunch) { @@ -555,7 +544,7 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne( } MLIRContext *ctx = getContext(); - SimpleRewriter rewriter(ctx); + TrivialPatternRewriter rewriter(ctx); rewriter.setInsertionPoint(target); SmallVector threadMappingAttributes = { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 53709a2..4bfa5b4 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -41,14 +42,6 @@ static SmallVector extractUIntArray(ArrayAttr attr) { return result; } -namespace { -/// A simple pattern rewriter that implements no special logic. -class SimpleRewriter : public PatternRewriter { -public: - SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} -}; -} // namespace - /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` /// function that returns the "main" result or failure. Returns failure if the @@ -65,7 +58,7 @@ static FailureOr tryApply(Operation *operation, Args &&...args) { // Apply the pattern directly to the op. PatternTy pattern(operation->getContext(), std::forward(args)...); - SimpleRewriter rewriter(operation->getContext()); + TrivialPatternRewriter rewriter(operation->getContext()); rewriter.setInsertionPoint(operation); auto result = pattern.returningMatchAndRewrite(op, rewriter); if (failed(result)) @@ -125,7 +118,7 @@ static LogicalResult applyTilingToAll( if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); - SimpleRewriter rewriter(target->getContext()); + TrivialPatternRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr tiledResults = applyFn(tilingInterfaceOp); @@ -209,7 +202,7 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, tileSizes.size() - llvm::count(tileSizes, 0), transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr { - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); return tileConsumerAndFuseProducerGreedilyUsingSCFForOp( rewriter, tilingInterfaceOp, tileAndFuseOptions); }); @@ -601,7 +594,7 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target, results.push_back(target); return DiagnosedSilenceableFailure(success()); } - SimpleRewriter rewriter(target->getContext()); + TrivialPatternRewriter rewriter(target->getContext()); FailureOr res = interchangeGenericOp(rewriter, target, interchangeVector); if (failed(res)) @@ -875,7 +868,7 @@ transform::PromoteOp::applyToOne(linalg::LinalgOp target, if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); - SimpleRewriter rewriter(target->getContext()); + TrivialPatternRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) @@ -974,7 +967,7 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, return tileSizes; }); SmallVector emptyTileSizes; - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(target.getOperation()), tilingOptions); @@ -993,7 +986,7 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, TransformState &state) { // Collect the dynamic split points if provided. ArrayRef payload = state.getPayloadOps(getTarget()); - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { @@ -1122,8 +1115,7 @@ void SplitOp::print(OpAsmPrinter &printer) { } LogicalResult SplitOp::verify() { - if ((static_cast(getStaticSplitPoint()) != - ShapedType::kDynamic) ^ + if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamic) ^ (getDynamicSplitPoint() == nullptr)) { return emitOpError() << "expects either a dynamic or a static split " "point to be provided"; @@ -1172,7 +1164,7 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, unsigned(getInsertSplitDimension()), bool(getInnerParallel())}; }; - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) @@ -1195,7 +1187,7 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); SmallVector tileSizes = extractFromI64ArrayAttr(getTileSizes()); SmallVector sizes; @@ -1223,7 +1215,7 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForeachThreadOp::applyToOne( linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); SmallVector numThreads = extractFromI64ArrayAttr(getNumThreads()); SmallVector numThreadResults; @@ -1321,7 +1313,7 @@ transform::TileOp::apply(TransformResults &transformResults, } tilingOptions.setInterchange(getInterchange()); - SimpleRewriter rewriter(linalgOp.getContext()); + TrivialPatternRewriter rewriter(linalgOp.getContext()); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(linalgOp.getOperation()), tilingOptions); @@ -1714,7 +1706,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults, } tilingOptions.setInterchange(getInterchange()); - SimpleRewriter rewriter(tilingInterfaceOp.getContext()); + TrivialPatternRewriter rewriter(tilingInterfaceOp.getContext()); FailureOr tilingResult = tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); if (failed(tilingResult)) diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 02c18c8..8777662 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -16,18 +16,11 @@ #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" using namespace mlir; -namespace { -/// A simple pattern rewriter that implements no special logic. -class SimpleRewriter : public PatternRewriter { -public: - SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} -}; -} // namespace - //===----------------------------------------------------------------------===// // GetParentForOp //===----------------------------------------------------------------------===// @@ -97,7 +90,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results, for (Operation *target : state.getPayloadOps(getTarget())) { Location location = target->getLoc(); Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); if (!exec) { DiagnosedSilenceableFailure diag = emitSilenceableError() @@ -201,7 +194,7 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target, getReadLatency()); }; scf::ForLoopPipeliningPattern pattern(options, target->getContext()); - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr patternResult = pattern.returningMatchAndRewrite(target, rewriter); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 76e0c89..98ab3f7 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -33,14 +34,6 @@ using namespace mlir; //===----------------------------------------------------------------------===// namespace { -/// A simple pattern rewriter that can be constructed from a context. This is -/// necessary to apply patterns to a specific op locally. -class TrivialPatternRewriter : public PatternRewriter { -public: - explicit TrivialPatternRewriter(MLIRContext *context) - : PatternRewriter(context) {} -}; - /// A TransformState extension that keeps track of compiled PDL pattern sets. /// This is intended to be used along the WithPDLPatterns op. The extension /// can be constructed given an operation that has a SymbolTable trait and @@ -109,7 +102,7 @@ LogicalResult PatternApplicatorExtension::findAllMatches( } PatternApplicator applicator(it->second); - TrivialPatternRewriter rewriter(root->getContext()); + transform::TrivialPatternRewriter rewriter(root->getContext()); applicator.applyDefaultCostModel(); root->walk([&](Operation *op) { if (succeeded(applicator.matchAndRewrite(op, rewriter))) -- 2.7.4