[mlir] Lower gpu.memcpy to GPU runtime calls.
authorChristian Sigg <csigg@google.com>
Tue, 22 Dec 2020 16:42:59 +0000 (17:42 +0100)
committerChristian Sigg <csigg@google.com>
Tue, 22 Dec 2020 21:49:19 +0000 (22:49 +0100)
Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D93204

mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir [new file with mode: 0644]
mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp

index 3b4b39e..41a079c 100644 (file)
@@ -151,6 +151,12 @@ protected:
       "mgpuMemFree",
       llvmVoidType,
       {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
+  FunctionCallBuilder memcpyCallBuilder = {
+      "mgpuMemcpy",
+      llvmVoidType,
+      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
+       llvmIntPtrType /* intptr_t sizeBytes */,
+       llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -268,6 +274,20 @@ class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
     return success();
   }
 };
+
+/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertMemcpyOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
+public:
+  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
@@ -643,6 +663,50 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
+LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();
+
+  if (failed(areAllLLVMTypes(memcpyOp, operands, rewriter)) ||
+      !isSupportedMemRefType(memRefType) ||
+      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
+    return failure();
+
+  auto loc = memcpyOp.getLoc();
+  auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());
+
+  MemRefDescriptor srcDesc(adaptor.src());
+
+  Value numElements =
+      memRefType.hasStaticShape()
+          ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
+          // For identity layouts (verified above), the number of elements is
+          // stride[0] * size[0].
+          : rewriter.create<LLVM::MulOp>(loc, srcDesc.stride(rewriter, loc, 0),
+                                         srcDesc.size(rewriter, loc, 0));
+
+  Type elementPtrType = getElementPtrType(memRefType);
+  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
+  Value gepPtr = rewriter.create<LLVM::GEPOp>(
+      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
+  auto sizeBytes =
+      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+
+  auto src = rewriter.create<LLVM::BitcastOp>(
+      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
+  auto dst = rewriter.create<LLVM::BitcastOp>(
+      loc, llvmPointerType,
+      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));
+
+  auto stream = adaptor.asyncDependencies().front();
+  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
+
+  rewriter.replaceOp(memcpyOp, {stream});
+
+  return success();
+}
+
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 mlir::createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
   return std::make_unique<GpuToLLVMConversionPass>(gpuBinaryAnnotation);
@@ -658,6 +722,7 @@ void mlir::populateGpuToLLVMConversionPatterns(
   patterns.insert<ConvertAllocOpToGpuRuntimeCallPattern,
                   ConvertDeallocOpToGpuRuntimeCallPattern,
                   ConvertHostRegisterOpToGpuRuntimeCallPattern,
+                  ConvertMemcpyOpToGpuRuntimeCallPattern,
                   ConvertWaitAsyncOpToGpuRuntimeCallPattern,
                   ConvertWaitOpToGpuRuntimeCallPattern>(converter);
   patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
new file mode 100644 (file)
index 0000000..790c92f
--- /dev/null
@@ -0,0 +1,19 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s --gpu-to-llvm | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+  // CHECK: func @foo
+  func @foo(%dst : memref<7xf32, 1>, %src : memref<7xf32>) {
+    // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint
+    // CHECK: %[[src:.*]] = llvm.bitcast
+    // CHECK: %[[dst:.*]] = llvm.bitcast
+    // CHECK: llvm.call @mgpuMemcpy(%[[dst]], %[[src]], %[[size_bytes]], %[[t0]])
+    %t1 = gpu.memcpy async [%t0] %dst, %src : memref<7xf32, 1>, memref<7xf32>
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
+    gpu.wait [%t1]
+    return
+  }
+}
index a6729b1..72d1728 100644 (file)
@@ -117,6 +117,13 @@ extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) {
   CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
 }
 
+extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+                           CUstream stream) {
+  CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
+                                     reinterpret_cast<CUdeviceptr>(src),
+                                     sizeBytes, stream));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the CUDA runtime. Helpful until we have
index aad7ae2..4f62f20 100644 (file)
@@ -118,6 +118,11 @@ extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
   HIP_REPORT_IF_ERROR(hipMemFree(ptr));
 }
 
+extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+                           hipStream_t stream) {
+  HIP_REPORT_IF_ERROR(hipMemcpyAsync(dst, src, sizeBytes, stream));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the ROCM runtime. Helpful until we have