[MLIR] Use memref.copy ops in BufferResultsToOutParams pass.
authorcwz920716 <cwz920716@gmail.com>
Wed, 15 Sep 2021 02:59:18 +0000 (02:59 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Wed, 15 Sep 2021 02:59:30 +0000 (02:59 +0000)
Both copy/alloc ops are using memref dialect after this change.

Reviewed By: silvas, mehdi_amini

Differential Revision: https://reviews.llvm.org/D109480

mlir/include/mlir/Transforms/Passes.td
mlir/lib/Transforms/BufferResultsToOutParams.cpp
mlir/lib/Transforms/CMakeLists.txt
mlir/lib/Transforms/PassDetail.h
mlir/test/Transforms/buffer-results-to-out-params.mlir

index 45d72c0..91af2a2 100644 (file)
@@ -352,7 +352,7 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     works for static shaped memrefs.
   }];
   let constructor = "mlir::createBufferResultsToOutParamsPass()";
-  let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
+  let dependentDialects = ["memref::MemRefDialect"];
 }
 
 def Canonicalizer : Pass<"canonicalize"> {
index 0920d13..73cc073 100644 (file)
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Operation.h"
@@ -71,7 +70,7 @@ static void updateReturnOps(FuncOp func,
     }
     OpBuilder builder(op);
     for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
-      builder.create<linalg::CopyOp>(op.getLoc(), std::get<0>(t),
+      builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
                                      std::get<1>(t));
     builder.create<ReturnOp>(op.getLoc(), keepAsReturnOperands);
     op.erase();
index 99133af..54f3693 100644 (file)
@@ -33,7 +33,6 @@ add_mlir_library(MLIRTransforms
   MLIRAffine
   MLIRAnalysis
   MLIRCopyOpInterface
-  MLIRLinalg
   MLIRLoopLikeInterface
   MLIRMemRef
   MLIRSCF
index 0f998a7..2cb0e12 100644 (file)
@@ -18,10 +18,6 @@ class AffineDialect;
 template <typename ConcreteDialect>
 void registerDialect(DialectRegistry &registry);
 
-namespace linalg {
-class LinalgDialect;
-} // end namespace linalg
-
 namespace memref {
 class MemRefDialect;
 } // end namespace memref
index cac3e74..063d0d3 100644 (file)
@@ -3,7 +3,7 @@
 // CHECK-LABEL:   func @basic(
 // CHECK-SAME:                %[[ARG:.*]]: memref<f32>) {
 // CHECK:           %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
-// CHECK:           linalg.copy(%[[RESULT]], %[[ARG]]) : memref<f32>, memref<f32>
+// CHECK:           memref.copy %[[RESULT]], %[[ARG]]  : memref<f32> to memref<f32>
 // CHECK:           return
 // CHECK:         }
 func @basic() -> (memref<f32>) {
@@ -15,7 +15,7 @@ func @basic() -> (memref<f32>) {
 // CHECK-SAME:                                         %[[ARG0:.*]]: memref<1xf32>,
 // CHECK-SAME:                                         %[[ARG1:.*]]: memref<2xf32>) {
 // CHECK:           %[[RESULT:.*]] = "test.source"() : () -> memref<2xf32>
-// CHECK:           linalg.copy(%[[RESULT]], %[[ARG1]]) : memref<2xf32>, memref<2xf32>
+// CHECK:           memref.copy %[[RESULT]], %[[ARG1]]  : memref<2xf32> to memref<2xf32>
 // CHECK:           return
 // CHECK:         }
 func @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) {
@@ -27,8 +27,8 @@ func @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) {
 // CHECK-SAME:                           %[[ARG0:.*]]: memref<1xf32>,
 // CHECK-SAME:                           %[[ARG1:.*]]: memref<2xf32>) {
 // CHECK:           %[[RESULTS:.*]]:2 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
-// CHECK:           linalg.copy(%[[RESULTS]]#0, %[[ARG0]]) : memref<1xf32>, memref<1xf32>
-// CHECK:           linalg.copy(%[[RESULTS]]#1, %[[ARG1]]) : memref<2xf32>, memref<2xf32>
+// CHECK:           memref.copy %[[RESULTS]]#0, %[[ARG0]]  : memref<1xf32> to memref<1xf32>
+// CHECK:           memref.copy %[[RESULTS]]#1, %[[ARG1]]  : memref<2xf32> to memref<2xf32>
 // CHECK:           return
 // CHECK:         }
 func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
@@ -39,7 +39,7 @@ func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
 // CHECK-LABEL:   func @non_memref_types(
 // CHECK-SAME:                           %[[OUTPARAM:.*]]: memref<f32>) -> (i1, i32) {
 // CHECK:           %[[RESULT1:.*]]:3 = "test.source"() : () -> (i1, memref<f32>, i32)
-// CHECK:           linalg.copy(%[[RESULT1]]#1, %[[OUTPARAM]]) : memref<f32>, memref<f32>
+// CHECK:           memref.copy %[[RESULT1]]#1, %[[OUTPARAM]]  : memref<f32> to memref<f32>
 // CHECK:           return %[[RESULT1]]#0, %[[RESULT1]]#2 : i1, i32
 // CHECK:         }
 func @non_memref_types() -> (i1, memref<f32>, i32) {