[mlir] Add gpu.memcpy op.
authorChristian Sigg <csigg@google.com>
Tue, 22 Dec 2020 16:39:00 +0000 (17:39 +0100)
committerChristian Sigg <csigg@google.com>
Tue, 22 Dec 2020 16:39:55 +0000 (17:39 +0100)
Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D93197

mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/ops.mlir

index 953a2d5..457477f 100644 (file)
@@ -879,4 +879,39 @@ def GPU_DeallocOp : GPU_Op<"dealloc", [
   }];
 }
 
+def GPU_MemcpyOp : GPU_Op<"memcpy", [
+    GPU_AsyncOpInterface, MemoryEffects<[MemRead, MemWrite]>
+  ]> {
+
+  let summary = "GPU memcpy operation";
+
+  let description = [{
+    The `gpu.memcpy` operation copies the content of one memref to another.
+
+    The op does not execute before all async dependencies have finished
+    executing.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a !gpu.async.token.
+
+    Example:
+
+    ```mlir
+    %token = gpu.memcpy async [%dep] %dst, %src : memref<?xf32, 1>, memref<?xf32>
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                   Arg<AnyMemRef, "", [MemWrite]>:$dst,
+                   Arg<AnyMemRef, "", [MemRead]>:$src);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $dst`,` $src `:` type($dst)`,` type($src) attr-dict
+  }];
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif // GPU_OPS
index e8a90ac..d3fa2cc 100644 (file)
@@ -22,6 +22,7 @@
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -842,6 +843,23 @@ static void print(OpAsmPrinter &p, GPUModuleOp op) {
                 /*printBlockTerminators=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// GPUMemcpyOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(MemcpyOp op) {
+  auto srcType = op.src().getType();
+  auto dstType = op.dst().getType();
+
+  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
+    return op.emitOpError("arguments have incompatible element type");
+
+  if (failed(verifyCompatibleShape(srcType, dstType)))
+    return op.emitOpError("arguments have incompatible shape");
+
+  return success();
+}
+
 static ParseResult parseAsyncDependencies(
     OpAsmParser &parser, Type &asyncTokenType,
     SmallVectorImpl<OpAsmParser::OperandType> &asyncDependencies) {
index 3dc5be4..1f6058c 100644 (file)
@@ -444,3 +444,17 @@ func @async_wait_without_result() {
   // expected-error @+1 {{custom op 'gpu.wait' needs to be named when marked 'async'}}
   gpu.wait async
 }
+
+// -----
+
+func @memcpy_incompatible_type(%dst : memref<?xf32>, %src : memref<?xi32>) {
+  // expected-error @+1 {{'gpu.memcpy' op arguments have incompatible element type}}
+  gpu.memcpy %dst, %src  : memref<?xf32>, memref<?xi32>
+}
+
+// -----
+
+func @memcpy_incompatible_shape(%dst : memref<7xf32>, %src : memref<9xf32>) {
+  // expected-error @+1 {{'gpu.memcpy' op arguments have incompatible shape}}
+  gpu.memcpy %dst, %src  : memref<7xf32>, memref<9xf32>
+}
index aed4368..5cea772 100644 (file)
@@ -183,4 +183,15 @@ module attributes {gpu.container_module} {
     gpu.wait // Valid, but a no-op.
     return
   }
+
+  func @memcpy(%dst : memref<3x7xf32>, %src : memref<3x7xf32, 1>) {
+    // CHECK-LABEL: func @memcpy
+    // CHECK: gpu.memcpy {{.*}}, {{.*}} : memref<3x7xf32>, memref<3x7xf32, 1>
+    gpu.memcpy %dst, %src : memref<3x7xf32>, memref<3x7xf32, 1>
+    // CHECK: %[[t0:.*]] = gpu.wait async
+    %0 = gpu.wait async
+    // CHECK: {{.*}} = gpu.memcpy async [%[[t0]]] {{.*}}, {{.*}} : memref<3x7xf32>, memref<3x7xf32, 1>
+    %1 = gpu.memcpy async [%0] %dst, %src : memref<3x7xf32>, memref<3x7xf32, 1>
+    return
+  }
 }