[mlir][Transform][NFC] Use a single rewriter instead of duplicating it everywhere
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 1 Dec 2022 11:48:23 +0000 (03:48 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 1 Dec 2022 11:54:31 +0000 (03:54 -0800)
Differential Revision: https://reviews.llvm.org/D139094

mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h [new file with mode: 0644]
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h b/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h
new file mode 100644 (file)
index 0000000..512c915
--- /dev/null
@@ -0,0 +1,29 @@
+//===- TransformUtils.h - Transform Dialect Utils ----------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace transform {
+
+/// A simple pattern rewriter that can be constructed from a context. This is
+/// necessary to apply patterns to a specific op locally.
+class TrivialPatternRewriter : public PatternRewriter {
+public:
+  explicit TrivialPatternRewriter(MLIRContext *context)
+      : PatternRewriter(context) {}
+};
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H
index 605c07f..eafc6d9 100644 (file)
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 
 using namespace mlir;
 
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
-  SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
index a3d0261..e035a26 100644 (file)
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Value.h"
-#include "llvm/ADT/None.h"
-#include "llvm/ADT/Optional.h"
 
 using namespace mlir;
 using namespace mlir::gpu;
 using namespace mlir::transform;
 
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
-  SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
 /// Check if given mapping attributes are one of the desired attributes
 static DiagnosedSilenceableFailure
 checkAttributeType(ArrayRef<DeviceMappingAttrInterface> threadMappingAttributes,
@@ -135,7 +124,7 @@ createGpuLaunch(RewriterBase &rewriter, Location loc,
 
 /// Alter kernel configuration of the given kernel.
 static DiagnosedSilenceableFailure
-alterGpuLaunch(SimpleRewriter &rewriter, LaunchOp gpuLaunch,
+alterGpuLaunch(TrivialPatternRewriter &rewriter, LaunchOp gpuLaunch,
                TransformOpInterface transformOp,
                Optional<int64_t> gridDimX = llvm::None,
                Optional<int64_t> gridDimY = llvm::None,
@@ -305,7 +294,7 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
                                           SmallVectorImpl<Operation *> &results,
                                           transform::TransformState &state) {
   LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
-  SimpleRewriter rewriter(getContext());
+  TrivialPatternRewriter rewriter(getContext());
   auto transformOp = cast<TransformOpInterface>(getOperation());
 
   if (!getGenerateGpuLaunch() && !gpuLaunch) {
@@ -555,7 +544,7 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
   }
 
   MLIRContext *ctx = getContext();
-  SimpleRewriter rewriter(ctx);
+  TrivialPatternRewriter rewriter(ctx);
   rewriter.setInsertionPoint(target);
 
   SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
index 53709a2..4bfa5b4 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -41,14 +42,6 @@ static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
   return result;
 }
 
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
-  SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
 /// Attempts to apply the pattern specified as template argument to the given
 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
 /// function that returns the "main" result or failure. Returns failure if the
@@ -65,7 +58,7 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
 
   // Apply the pattern directly to the op.
   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
-  SimpleRewriter rewriter(operation->getContext());
+  TrivialPatternRewriter rewriter(operation->getContext());
   rewriter.setInsertionPoint(operation);
   auto result = pattern.returningMatchAndRewrite(op, rewriter);
   if (failed(result))
@@ -125,7 +118,7 @@ static LogicalResult applyTilingToAll(
     if (!tilingInterfaceOp)
       return transformOp->emitError("only TilingInterface ops are supported");
 
-    SimpleRewriter rewriter(target->getContext());
+    TrivialPatternRewriter rewriter(target->getContext());
     rewriter.setInsertionPoint(target);
     FailureOr<scf::SCFTileAndFuseResult> tiledResults =
         applyFn(tilingInterfaceOp);
@@ -209,7 +202,7 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
       tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
       [&](TilingInterface tilingInterfaceOp)
           -> FailureOr<scf::SCFTileAndFuseResult> {
-        SimpleRewriter rewriter(getContext());
+        TrivialPatternRewriter rewriter(getContext());
         return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
             rewriter, tilingInterfaceOp, tileAndFuseOptions);
       });
@@ -601,7 +594,7 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
     results.push_back(target);
     return DiagnosedSilenceableFailure(success());
   }
-  SimpleRewriter rewriter(target->getContext());
+  TrivialPatternRewriter rewriter(target->getContext());
   FailureOr<GenericOp> res =
       interchangeGenericOp(rewriter, target, interchangeVector);
   if (failed(res))
@@ -875,7 +868,7 @@ transform::PromoteOp::applyToOne(linalg::LinalgOp target,
   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
-  SimpleRewriter rewriter(target->getContext());
+  TrivialPatternRewriter rewriter(target->getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
   if (failed(res))
@@ -974,7 +967,7 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
     return tileSizes;
   });
   SmallVector<int64_t> emptyTileSizes;
-  SimpleRewriter rewriter(getContext());
+  TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
       rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
@@ -993,7 +986,7 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
                                            TransformState &state) {
   // Collect the dynamic split points if provided.
   ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
-  SimpleRewriter rewriter(getContext());
+  TrivialPatternRewriter rewriter(getContext());
   SmallVector<OpFoldResult> splitPoints;
   splitPoints.reserve(payload.size());
   if (getDynamicSplitPoint()) {
@@ -1122,8 +1115,7 @@ void SplitOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult SplitOp::verify() {
-  if ((static_cast<int64_t>(getStaticSplitPoint()) !=
-       ShapedType::kDynamic) ^
+  if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
       (getDynamicSplitPoint() == nullptr)) {
     return emitOpError() << "expects either a dynamic or a static split "
                             "point to be provided";
@@ -1172,7 +1164,7 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
                                          unsigned(getInsertSplitDimension()),
                                          bool(getInnerParallel())};
   };
-  SimpleRewriter rewriter(getContext());
+  TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<SplitReductionResult> splitResult =
       (getUseScalingAlgorithm())
@@ -1195,7 +1187,7 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
 DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
     linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
     transform::TransformState &state) {
-  SimpleRewriter rewriter(getContext());
+  TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
   SmallVector<OpFoldResult> sizes;
@@ -1223,7 +1215,7 @@ DiagnosedSilenceableFailure
 transform::TileReductionUsingForeachThreadOp::applyToOne(
     linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
     transform::TransformState &state) {
-  SimpleRewriter rewriter(getContext());
+  TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   SmallVector<int64_t> numThreads = extractFromI64ArrayAttr(getNumThreads());
   SmallVector<OpFoldResult> numThreadResults;
@@ -1321,7 +1313,7 @@ transform::TileOp::apply(TransformResults &transformResults,
     }
 
     tilingOptions.setInterchange(getInterchange());
-    SimpleRewriter rewriter(linalgOp.getContext());
+    TrivialPatternRewriter rewriter(linalgOp.getContext());
     FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
         rewriter, cast<TilingInterface>(linalgOp.getOperation()),
         tilingOptions);
@@ -1714,7 +1706,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
     }
 
     tilingOptions.setInterchange(getInterchange());
-    SimpleRewriter rewriter(tilingInterfaceOp.getContext());
+    TrivialPatternRewriter rewriter(tilingInterfaceOp.getContext());
     FailureOr<scf::SCFTilingResult> tilingResult =
         tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
     if (failed(tilingResult))
index 02c18c8..8777662 100644 (file)
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
 using namespace mlir;
 
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
-  SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // GetParentForOp
 //===----------------------------------------------------------------------===//
@@ -97,7 +90,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
   for (Operation *target : state.getPayloadOps(getTarget())) {
     Location location = target->getLoc();
     Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
-    SimpleRewriter rewriter(getContext());
+    TrivialPatternRewriter rewriter(getContext());
     scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
     if (!exec) {
       DiagnosedSilenceableFailure diag = emitSilenceableError()
@@ -201,7 +194,7 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target,
                        getReadLatency());
       };
   scf::ForLoopPipeliningPattern pattern(options, target->getContext());
-  SimpleRewriter rewriter(getContext());
+  TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<scf::ForOp> patternResult =
       pattern.returningMatchAndRewrite(target, rewriter);
index 76e0c89..98ab3f7 100644 (file)
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -33,14 +34,6 @@ using namespace mlir;
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// A simple pattern rewriter that can be constructed from a context. This is
-/// necessary to apply patterns to a specific op locally.
-class TrivialPatternRewriter : public PatternRewriter {
-public:
-  explicit TrivialPatternRewriter(MLIRContext *context)
-      : PatternRewriter(context) {}
-};
-
 /// A TransformState extension that keeps track of compiled PDL pattern sets.
 /// This is intended to be used along the WithPDLPatterns op. The extension
 /// can be constructed given an operation that has a SymbolTable trait and
@@ -109,7 +102,7 @@ LogicalResult PatternApplicatorExtension::findAllMatches(
   }
 
   PatternApplicator applicator(it->second);
-  TrivialPatternRewriter rewriter(root->getContext());
+  transform::TrivialPatternRewriter rewriter(root->getContext());
   applicator.applyDefaultCostModel();
   root->walk([&](Operation *op) {
     if (succeeded(applicator.matchAndRewrite(op, rewriter)))