[mlir][memref] Add memref.copy operation
authorStephan Herhut <herhut@google.com>
Mon, 21 Jun 2021 17:33:28 +0000 (19:33 +0200)
committerStephan Herhut <herhut@google.com>
Tue, 22 Jun 2021 11:21:44 +0000 (13:21 +0200)
As the name suggests, it copies from one memref to another.

Differential Revision: https://reviews.llvm.org/D104657

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/MemRef/ops.mlir

index 7bcc53d..19f1c3c 100644 (file)
@@ -445,6 +445,43 @@ def CloneOp : MemRef_Op<"clone", [
 }
 
 //===----------------------------------------------------------------------===//
+// CopyOp
+//===----------------------------------------------------------------------===//
+
+def CopyOp : MemRef_Op<"copy",
+    [CopyOpInterface, SameOperandsElementType, SameOperandsShape]> {
+
+  let description = [{
+    Copies the data from the source to the destination memref.
+
+    Usage:
+
+    ```mlir
+    memref.copy %arg0, %arg1 : memref<?xf32> to memref<?xf32>
+    ```
+
+    Source and destination are expected to have the same element type and shape.
+    Otherwise, the result is undefined. They may have different layouts.
+  }];
+
+  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "the memref to copy from",
+                           [MemRead]>:$source,
+                       Arg<AnyRankedOrUnrankedMemRef, "the memref to copy to",
+                           [MemWrite]>:$target);
+
+  let extraClassDeclaration = [{
+    Value getSource() { return source();}
+    Value getTarget() { return target(); }
+  }];
+
+  let assemblyFormat = [{
+    $source `,` $target attr-dict `:` type($source) `to` type($target)
+  }];
+
+  let verifier = ?;
+}
+
+//===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
 
index 06c1b20..63209ef 100644 (file)
@@ -215,3 +215,19 @@ func @mismatched_types() {
   %0 = memref.get_global @gv : memref<3xf32>
   return
 }
+
+// -----
+
+func @copy_different_shape(%arg0: memref<2xf32>, %arg1: memref<3xf32>) {
+  // expected-error @+1 {{op requires the same shape for all operands}}
+  memref.copy %arg0, %arg1 : memref<2xf32> to memref<3xf32>
+  return
+}
+
+// -----
+
+func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) {
+  // expected-error @+1 {{op requires the same element type for all operands}}
+  memref.copy %arg0, %arg1 : memref<2xf32> to memref<2xf16>
+  return
+}
index bbd7fb3..993a613 100644 (file)
@@ -69,6 +69,16 @@ func @memref_clone() {
   return
 }
 
+// CHECK-LABEL: func @memref_copy
+func @memref_copy() {
+  %0 = memref.alloc() : memref<2xf32>
+  %1 = memref.cast %0 : memref<2xf32> to memref<*xf32>
+  %2 = memref.alloc() : memref<2xf32>
+  %3 = memref.cast %0 : memref<2xf32> to memref<*xf32>
+  memref.copy %1, %3 : memref<*xf32> to memref<*xf32>
+  return
+}
+
 // CHECK-LABEL: func @memref_dealloc
 func @memref_dealloc() {
   %0 = memref.alloc() : memref<2xf32>