[mlir][linalg] Move generalization pattern to Transforms (NFC).
authorTobias Gysi <gysit@google.com>
Tue, 5 Oct 2021 12:24:19 +0000 (12:24 +0000)
committerTobias Gysi <gysit@google.com>
Tue, 5 Oct 2021 12:49:42 +0000 (12:49 +0000)
Move the generalization pattern to the other Linalg transforms to make it available to the codegen strategy.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110728

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

index 4a76b92..b2bd870 100644 (file)
@@ -234,6 +234,10 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
 void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
                           ArrayRef<unsigned> interchangeVector);
 
+/// Creates a GenericOp from the given named operation `namedOp`. Assumes
+/// `namedOp` is not a GenericOp and has a region builder.
+GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
+
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
 /// smallest constant value for the size of the buffer needed for each
@@ -380,6 +384,9 @@ LogicalResult
 interchangeGenericOpPrecondition(GenericOp genericOp,
                                  ArrayRef<unsigned> interchangeVector);
 
+/// Generalize named operations to generic operations.
+LogicalResult generalizeNamedOpPrecondition(Operation *op);
+
 /// Promote std.subviews feeding linalg operations.
 LogicalResult promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options);
@@ -702,6 +709,31 @@ private:
 };
 
 ///
+/// Linalg generalization pattern.
+///
+/// Apply the `generalization` transformation as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `generalization` for more details.
+struct LinalgGeneralizationPattern : public RewritePattern {
+  // Entry point to match any LinalgOp OpInterface.
+  LinalgGeneralizationPattern(
+      MLIRContext *context,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+  // Entry point to match a specific Linalg op.
+  LinalgGeneralizationPattern(
+      StringRef opName, MLIRContext *context,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+};
+
+///
 /// Linalg promotion patterns.
 ///
 /// Apply the `promoteSubViews` transformation as a pattern.
index d0d14f8..afa78b8 100644 (file)
 using namespace mlir;
 using namespace mlir::linalg;
 
-// Creates a linalg.generic op from the given `namedOp`. Returns a null op if
-// the given `namedOp` does not have a region builder.
-static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
-                                            PatternRewriter &rewriter) {
+LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
+  LinalgOp namedOp = dyn_cast<LinalgOp>(op);
+  // Check if the operation is a LinalgOp but not a GenericOp.
+  if (!namedOp || isa<GenericOp>(op))
+    return failure();
+  // Check if the operation has a region builder.
+  if (!namedOp.getRegionBuilder())
+    return failure();
+  return success();
+}
+
+GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
+                                          LinalgOp namedOp) {
   SmallVector<Value> inputOperands = namedOp.getInputOperands();
   SmallVector<Value> outputOperands = namedOp.getOutputOperands();
   SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
@@ -54,10 +63,7 @@ static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
   // Otherwise use the region builder to generate a new region.
   // TODO: Remove this path once all linag operations have a region attached.
   auto regionBuilder = namedOp.getRegionBuilder();
-  if (!regionBuilder) {
-    LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
-    return nullptr;
-  }
+  assert(regionBuilder && "expect the operation to have region builder");
   return rewriter.create<GenericOp>(
       namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
       iterators,
@@ -112,41 +118,6 @@ struct GeneralizeConvOp
   GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
 };
 
-/// Catch-all pattern for converting all named ops with a region builder into
-/// linalg.generic.
-struct LinalgNamedOpGeneralizationPattern : RewritePattern {
-  LinalgNamedOpGeneralizationPattern(MLIRContext *context,
-                                     LinalgTransformationFilter marker,
-                                     PatternBenefit benefit = 1)
-      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
-        marker(std::move(marker)) {}
-
-  LogicalResult matchAndRewrite(Operation *rootOp,
-                                PatternRewriter &rewriter) const override {
-    auto linalgOp = dyn_cast<LinalgOp>(rootOp);
-    if (!linalgOp)
-      return failure();
-    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
-      return failure();
-
-    // No nothing to do for linalg.generic.
-    if (isa<GenericOp>(rootOp))
-      return failure();
-
-    GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
-    if (!genericOp)
-      return failure();
-
-    rewriter.replaceOp(rootOp, genericOp.getResults());
-    marker.replaceLinalgTransformationFilter(rewriter,
-                                             genericOp.getOperation());
-    return success();
-  }
-
-private:
-  LinalgTransformationFilter marker;
-};
-
 struct LinalgGeneralizationPass
     : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
   void runOnFunction() override;
@@ -187,8 +158,7 @@ void mlir::linalg::populateLinalgConvGeneralizationPatterns(
 
 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
     RewritePatternSet &patterns, LinalgTransformationFilter marker) {
-  patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
-                                                   marker);
+  patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
index aacb20c..34ff401 100644 (file)
@@ -488,6 +488,30 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
   return success();
 }
 
+/// Linalg generalization pattern.
+mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
+    MLIRContext *context, LinalgTransformationFilter filter,
+    PatternBenefit benefit)
+    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
+
+mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
+    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
+    PatternBenefit benefit)
+    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
+
+LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, op)))
+    return failure();
+  if (failed(generalizeNamedOpPrecondition(op)))
+    return failure();
+
+  GenericOp genericOp = generalizeNamedOp(rewriter, op);
+  rewriter.replaceOp(op, genericOp.getResults());
+  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
+  return success();
+}
+
 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
     MLIRContext *context, LinalgTransformationFilter filter,
     LinalgPromotionOptions options, PatternBenefit benefit)