[mlir][bufferization] OneShotBufferizeOp: Add options to use linalg.copy
authorMatthias Springer <me@m-sp.org>
Fri, 14 Jul 2023 09:57:54 +0000 (11:57 +0200)
committerMatthias Springer <me@m-sp.org>
Fri, 14 Jul 2023 11:34:22 +0000 (13:34 +0200)
This new option allows users to specify a custom memcpy op.

Differential Revision: https://reviews.llvm.org/D155280

mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 807a63d36f39364c807ed1c3c71e2112cdcca843..0a32afd0e19fe942098293bcbe80deb520d9629a 100644 (file)
@@ -58,10 +58,12 @@ def OneShotBufferizeOp
       DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
       DefaultValuedAttr<BoolAttr, "true">:$create_deallocs,
       DefaultValuedAttr<BoolAttr, "false">:$test_analysis_only,
-      DefaultValuedAttr<BoolAttr, "false">:$print_conflicts);
+      DefaultValuedAttr<BoolAttr, "false">:$print_conflicts,
+      DefaultValuedAttr<StrAttr, "\"memref.copy\"">:$memcpy_op);
 
   let results = (outs TransformHandleTypeInterface:$transformed);
 
+  let hasVerifier = 1;
   let assemblyFormat = [{
     (`layout` `{` $function_boundary_type_conversion^ `}`)?
     $target attr-dict `:` functional-type($target, results)
index 9c23ad6bfd9022bde1a7d42bec1a49859bf1bfe7..f866484f4856783edfbeca219ff790ac2c046b70 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -25,6 +26,12 @@ using namespace mlir::transform;
 // OneShotBufferizeOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult transform::OneShotBufferizeOp::verify() {
+  if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
+    return emitOpError() << "unsupported memcpy op";
+  return success();
+}
+
 DiagnosedSilenceableFailure
 transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
                                      TransformResults &transformResults,
@@ -39,6 +46,19 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
   if (getFunctionBoundaryTypeConversion().has_value())
     options.setFunctionBoundaryTypeConversion(
         *getFunctionBoundaryTypeConversion());
+  if (getMemcpyOp() == "memref.copy") {
+    options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
+      b.create<memref::CopyOp>(loc, from, to);
+      return success();
+    };
+  } else if (getMemcpyOp() == "linalg.copy") {
+    options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
+      b.create<linalg::CopyOp>(loc, from, to);
+      return success();
+    };
+  } else {
+    llvm_unreachable("invalid copy op");
+  }
 
   auto payloadOps = state.getPayloadOps(getTarget());
   for (Operation *target : payloadOps) {
index 51e5b0a099280b623d137e55f9ab7efa87f05c51..10ddabd7d840159e131a06248e3e77174c724056 100644 (file)
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRBufferizationTransformOps
   MLIRIR
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
+  MLIRLinalgDialect
   MLIRParser
   MLIRPDLDialect
   MLIRSideEffectInterfaces
index c4a40448919437afa32508cde9e238fdd819b358..94550c8d4374a511c08cc6fb2f6f8edcde8e1700 100644 (file)
@@ -28,6 +28,35 @@ func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf3
 
 // -----
 
+// Emit linalg.copy instead of memref.copy.
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.bufferization.one_shot_bufferize %0 {memcpy_op = "linalg.copy"} : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @test_function(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+//   CHECK-NOT:   memref.copy
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]]
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: memref.dealloc %[[alloc]]
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
 // Test analysis of One-Shot Bufferize only.
 
 transform.sequence failures(propagate) {
index 2097da8a1e0b18a24aa6806a1d1c591c6ee59787..14027fcb038ce571625bcdfc9b0bb7644a41cd7c 100644 (file)
@@ -11477,6 +11477,7 @@ cc_library(
         ":BufferizationTransformOpsIncGen",
         ":BufferizationTransforms",
         ":IR",
+        ":LinalgDialect",
         ":MemRefDialect",
         ":Parser",
         ":SideEffectInterfaces",