[mlir][Linalg] Relax restriction of Linalg passes on FuncOp
authorNicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Wed, 5 Jul 2023 11:49:18 +0000 (11:49 +0000)
committerNicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Wed, 5 Jul 2023 19:55:01 +0000 (19:55 +0000)
Existing Linalg passes are still anchoring on FuncOp.
Relax this unnecessary limitation.

Differential Revision: https://reviews.llvm.org/D154497

mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

index f2b5788..5f46aff 100644 (file)
@@ -38,31 +38,28 @@ std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
 
 std::unique_ptr<Pass> createLinalgNamedOpConversionPass();
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-createLinalgInlineScalarOperandsPass();
+std::unique_ptr<Pass> createLinalgInlineScalarOperandsPass();
 
 /// Create a pass to convert Linalg operations to scf.for loops and
 /// memref.load/memref.store accesses.
-std::unique_ptr<OperationPass<func::FuncOp>> createConvertLinalgToLoopsPass();
+std::unique_ptr<Pass> createConvertLinalgToLoopsPass();
 
 /// Create a pass to convert Linalg operations to scf.parallel loops and
 /// memref.load/memref.store accesses.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createConvertLinalgToParallelLoopsPass();
+std::unique_ptr<Pass> createConvertLinalgToParallelLoopsPass();
 
 /// Create a pass to convert Linalg operations to affine.for loops and
 /// affine_load/affine_store accesses.
 /// Placeholder for now, this is NYI.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createConvertLinalgToAffineLoopsPass();
+std::unique_ptr<Pass> createConvertLinalgToAffineLoopsPass();
 
 /// Create a pass to convert Linalg operations which work on tensors to use
 /// buffers instead.
-std::unique_ptr<OperationPass<func::FuncOp>> createLinalgBufferizePass();
+std::unique_ptr<Pass> createLinalgBufferizePass();
 
 /// Create a pass to convert named Linalg operations to Linalg generic
 /// operations.
-std::unique_ptr<OperationPass<func::FuncOp>> createLinalgGeneralizationPass();
+std::unique_ptr<Pass> createLinalgGeneralizationPass();
 
 /// Create a pass to convert Linalg operations to equivalent operations that
 /// work on primitive types, if possible.
index be4c870..1ed867b 100644 (file)
@@ -55,7 +55,7 @@ def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
   let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
 }
 
-def LinalgInlineScalarOperands : Pass<"linalg-inline-scalar-operands", "func::FuncOp"> {
+def LinalgInlineScalarOperands : Pass<"linalg-inline-scalar-operands"> {
   let summary = "Inline scalar operands into linalg generic ops";
   let constructor = "mlir::createLinalgInlineScalarOperandsPass()";
   let dependentDialects = [
@@ -63,7 +63,7 @@ def LinalgInlineScalarOperands : Pass<"linalg-inline-scalar-operands", "func::Fu
   ];
 }
 
-def LinalgLowerToAffineLoops : Pass<"convert-linalg-to-affine-loops", "func::FuncOp"> {
+def LinalgLowerToAffineLoops : Pass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
   let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
@@ -71,7 +71,7 @@ def LinalgLowerToAffineLoops : Pass<"convert-linalg-to-affine-loops", "func::Fun
     "affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"];
 }
 
-def LinalgLowerToLoops : Pass<"convert-linalg-to-loops", "func::FuncOp"> {
+def LinalgLowerToLoops : Pass<"convert-linalg-to-loops"> {
   let summary = "Lower the operations from the linalg dialect into loops";
   let constructor = "mlir::createConvertLinalgToLoopsPass()";
   let dependentDialects = [
@@ -82,7 +82,7 @@ def LinalgLowerToLoops : Pass<"convert-linalg-to-loops", "func::FuncOp"> {
 }
 
 def LinalgLowerToParallelLoops
-    : Pass<"convert-linalg-to-parallel-loops", "func::FuncOp"> {
+    : Pass<"convert-linalg-to-parallel-loops"> {
   let summary = "Lower the operations from the linalg dialect into parallel "
                 "loops";
   let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
@@ -94,7 +94,7 @@ def LinalgLowerToParallelLoops
   ];
 }
 
-def LinalgBufferize : Pass<"linalg-bufferize", "func::FuncOp"> {
+def LinalgBufferize : Pass<"linalg-bufferize"> {
   let summary = "Bufferize the linalg dialect";
   let constructor = "mlir::createLinalgBufferizePass()";
   let dependentDialects = [
@@ -105,7 +105,7 @@ def LinalgBufferize : Pass<"linalg-bufferize", "func::FuncOp"> {
   ];
 }
 
-def LinalgGeneralization : Pass<"linalg-generalize-named-ops", "func::FuncOp"> {
+def LinalgGeneralization : Pass<"linalg-generalize-named-ops"> {
   let summary = "Convert named ops into generic ops";
   let constructor = "mlir::createLinalgGeneralizationPass()";
   let dependentDialects = ["linalg::LinalgDialect"];
index ca74a17..73c4d47 100644 (file)
@@ -49,6 +49,6 @@ struct LinalgBufferizePass
 };
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>> mlir::createLinalgBufferizePass() {
+std::unique_ptr<Pass> mlir::createLinalgBufferizePass() {
   return std::make_unique<LinalgBufferizePass>();
 }
index 146217d..2382903 100644 (file)
@@ -84,10 +84,9 @@ struct LinalgGeneralizationPass
 } // namespace
 
 void LinalgGeneralizationPass::runOnOperation() {
-  func::FuncOp func = getOperation();
   RewritePatternSet patterns(&getContext());
   populateLinalgNamedOpsGeneralizationPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
@@ -95,7 +94,6 @@ void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
   patterns.add<LinalgGeneralizationPattern>(patterns.getContext());
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgGeneralizationPass() {
+std::unique_ptr<Pass> mlir::createLinalgGeneralizationPass() {
   return std::make_unique<LinalgGeneralizationPass>();
 }
index c9b118b..9cf8f7c 100644 (file)
@@ -103,17 +103,15 @@ struct LinalgInlineScalarOperandsPass
     : public impl::LinalgInlineScalarOperandsBase<
           LinalgInlineScalarOperandsPass> {
   void runOnOperation() override {
-    func::FuncOp funcOp = getOperation();
-    MLIRContext *context = funcOp.getContext();
-    RewritePatternSet patterns(context);
-
+    Operation *op = getOperation();
+    MLIRContext &ctx = getContext();
+    RewritePatternSet patterns(&ctx);
     populateInlineConstantOperandsPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgInlineScalarOperandsPass() {
+std::unique_ptr<Pass> mlir::createLinalgInlineScalarOperandsPass() {
   return std::make_unique<LinalgInlineScalarOperandsPass>();
 }
index d91d8c4..5a44a85 100644 (file)
@@ -312,8 +312,8 @@ struct FoldAffineOp : public RewritePattern {
 };
 
 template <typename LoopType>
-static void lowerLinalgToLoopsImpl(func::FuncOp funcOp) {
-  MLIRContext *context = funcOp.getContext();
+static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
+  MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
   memref::DimOp::getCanonicalizationPatterns(patterns, context);
@@ -321,7 +321,7 @@ static void lowerLinalgToLoopsImpl(func::FuncOp funcOp) {
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
   // Just apply the patterns greedily.
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  (void)applyPatternsAndFoldGreedily(enclosingOp, std::move(patterns));
 }
 
 struct LowerToAffineLoops
@@ -352,18 +352,15 @@ struct LowerToParallelLoops
 
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createConvertLinalgToLoopsPass() {
+std::unique_ptr<Pass> mlir::createConvertLinalgToLoopsPass() {
   return std::make_unique<LowerToLoops>();
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createConvertLinalgToParallelLoopsPass() {
+std::unique_ptr<Pass> mlir::createConvertLinalgToParallelLoopsPass() {
   return std::make_unique<LowerToParallelLoops>();
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createConvertLinalgToAffineLoopsPass() {
+std::unique_ptr<Pass> mlir::createConvertLinalgToAffineLoopsPass() {
   return std::make_unique<LowerToAffineLoops>();
 }