[mlir] create gpu memset op
authorLoren Maggiore <loreno@google.com>
Sat, 4 Sep 2021 06:03:33 +0000 (08:03 +0200)
committerChristian Sigg <csigg@google.com>
Sat, 4 Sep 2021 06:13:04 +0000 (08:13 +0200)
Create a gpu memset op and corresponding CUDA and ROCm wrappers.

Reviewed By: herhut, lorenrose1013

Differential Revision: https://reviews.llvm.org/D107548

mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir [new file with mode: 0644]
mlir/test/Dialect/GPU/canonicalize.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/ops.mlir

index facea7e..9290e16 100644 (file)
@@ -901,6 +901,42 @@ def GPU_MemcpyOp : GPU_Op<"memcpy", [GPU_AsyncOpInterface]> {
   let hasFolder = 1;
 }
 
+def GPU_MemsetOp : GPU_Op<"memset",
+  [GPU_AsyncOpInterface, AllElementTypesMatch<["dst", "value"]>]> {
+
+  let summary = "GPU memset operation";
+
+  let description = [{
+    The `gpu.memset` operation sets the content of memref to a scalar value.
+
+    The op does not execute before all async dependencies have finished
+    executing.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a !gpu.async.token.
+
+    Example:
+
+    ```mlir
+    %token = gpu.memset async [%dep] %dst, %value : memref<?xf32, 1>, f32
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                   Arg<AnyMemRef, "", [MemWrite]>:$dst,
+                   Arg<AnyType, "">:$value);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $dst`,` $value `:` type($dst)`,` type($value) attr-dict
+  }];
+  // MemsetOp is fully verified by traits.
+  let verifier = [{ return success(); }];
+  let hasFolder = 1;
+}
+
 def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     [MemoryEffects<[MemRead]>]>{
 
index 1234a9a..e4f7858 100644 (file)
@@ -79,6 +79,18 @@ public:
       : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
 
 protected:
+  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
+                       MemRefType type, MemRefDescriptor desc) const {
+    return type.hasStaticShape()
+               ? ConvertToLLVMPattern::createIndexConstant(
+                     rewriter, loc, type.getNumElements())
+               // For identity maps (verified by caller), the number of
+               // elements is stride[0] * size[0].
+               : rewriter.create<LLVM::MulOp>(loc,
+                                              desc.stride(rewriter, loc, 0),
+                                              desc.size(rewriter, loc, 0));
+  }
+
   MLIRContext *context = &this->getTypeConverter()->getContext();
 
   Type llvmVoidType = LLVM::LLVMVoidType::get(context);
@@ -165,6 +177,12 @@ protected:
       {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
        llvmIntPtrType /* intptr_t sizeBytes */,
        llvmPointerType /* void *stream */}};
+  FunctionCallBuilder memsetCallBuilder = {
+      "mgpuMemset32",
+      llvmVoidType,
+      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
+       llvmIntPtrType /* intptr_t sizeBytes */,
+       llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -308,6 +326,20 @@ private:
   matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
+
+/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertMemsetOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
+public:
+  ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
@@ -757,14 +789,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());
 
   MemRefDescriptor srcDesc(adaptor.src());
-
-  Value numElements =
-      memRefType.hasStaticShape()
-          ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
-          // For identity layouts (verified above), the number of elements is
-          // stride[0] * size[0].
-          : rewriter.create<LLVM::MulOp>(loc, srcDesc.stride(rewriter, loc, 0),
-                                         srcDesc.size(rewriter, loc, 0));
+  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
 
   Type elementPtrType = getElementPtrType(memRefType);
   Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
@@ -787,6 +812,40 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
+LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  auto memRefType = memsetOp.dst().getType().cast<MemRefType>();
+
+  if (failed(areAllLLVMTypes(memsetOp, operands, rewriter)) ||
+      !isConvertibleAndHasIdentityMaps(memRefType) ||
+      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
+    return failure();
+
+  auto loc = memsetOp.getLoc();
+  auto adaptor = gpu::MemsetOpAdaptor(operands, memsetOp->getAttrDictionary());
+
+  Type valueType = adaptor.value().getType();
+  if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
+    return rewriter.notifyMatchFailure(memsetOp,
+                                       "value must be a 32 bit scalar");
+  }
+
+  MemRefDescriptor dstDesc(adaptor.dst());
+  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
+
+  auto value =
+      rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
+  auto dst = rewriter.create<LLVM::BitcastOp>(
+      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));
+
+  auto stream = adaptor.asyncDependencies().front();
+  memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});
+
+  rewriter.replaceOp(memsetOp, {stream});
+  return success();
+}
+
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 mlir::createGpuToLLVMConversionPass() {
   return std::make_unique<GpuToLLVMConversionPass>();
@@ -803,6 +862,7 @@ void mlir::populateGpuToLLVMConversionPatterns(
                ConvertDeallocOpToGpuRuntimeCallPattern,
                ConvertHostRegisterOpToGpuRuntimeCallPattern,
                ConvertMemcpyOpToGpuRuntimeCallPattern,
+               ConvertMemsetOpToGpuRuntimeCallPattern,
                ConvertWaitAsyncOpToGpuRuntimeCallPattern,
                ConvertWaitOpToGpuRuntimeCallPattern,
                ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
index 1668276..f148d6e 100644 (file)
@@ -1079,6 +1079,11 @@ LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
   return foldMemRefCast(*this);
 }
 
+LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
+                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
+  return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // GPU_AllocOp
 //===----------------------------------------------------------------------===//
index 8220f03..9ee8057 100644 (file)
@@ -141,13 +141,19 @@ extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) {
   CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
 }
 
-extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
                            CUstream stream) {
   CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
                                      reinterpret_cast<CUdeviceptr>(src),
                                      sizeBytes, stream));
 }
 
+extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,
+                             CUstream stream) {
+  CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst),
+                                        value, count, stream));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the CUDA runtime. Helpful until we have
index 399a373..92358ed 100644 (file)
@@ -133,12 +133,17 @@ extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
   HIP_REPORT_IF_ERROR(hipFree(ptr));
 }
 
-extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
                            hipStream_t stream) {
   HIP_REPORT_IF_ERROR(
       hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
 }
 
+extern "C" void mgpuMemset32(void *dst, int value, size_t count,
+                             hipStream_t stream) {
+  HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst),
+                                        value, count, stream));
+}
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the ROCM runtime. Helpful until we have
diff --git a/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
new file mode 100644 (file)
index 0000000..1b786ed
--- /dev/null
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+  // CHECK: func @foo
+  func @foo(%dst : memref<7xf32, 1>, %value : f32) {
+    // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    // CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
+    // CHECK: %[[value:.*]] = llvm.bitcast
+    // CHECK: %[[dst:.*]] = llvm.bitcast
+    // CHECK: llvm.call @mgpuMemset32(%[[dst]], %[[value]], %[[size_bytes]], %[[t0]])
+    %t1 = gpu.memset async [%t0] %dst, %value : memref<7xf32, 1>, f32
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
+    gpu.wait [%t1]
+    return
+  }
+}
index 448e08b..38abfd1 100644 (file)
@@ -6,7 +6,16 @@ func @memcpy_after_cast(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
   // CHECK: gpu.memcpy
   %0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
   %1 = memref.cast %arg1 : memref<10xf32> to memref<?xf32>
-  gpu.memcpy %0,%1 : memref<?xf32>, memref<?xf32>
+  gpu.memcpy %0, %1 : memref<?xf32>, memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: @memset_after_cast
+func @memset_after_cast(%arg0: memref<10xf32>, %arg1: f32) {
+  // CHECK-NOT: memref.cast
+  // CHECK: gpu.memset
+  %0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
+  gpu.memset %0, %arg1 : memref<?xf32>, f32
   return
 }
 
index a1ed9fa..3c7a57d 100644 (file)
@@ -467,6 +467,13 @@ func @memcpy_incompatible_shape(%dst : memref<7xf32>, %src : memref<9xf32>) {
 
 // -----
 
+func @memset_incompatible_shape(%dst : memref<?xf32>, %value : i32) {
+  // expected-error @+1 {{'gpu.memset' op failed to verify that all of {dst, value} have same element type}}
+  gpu.memset %dst, %value  : memref<?xf32>, i32
+}
+
+// -----
+
 func @mmamatrix_invalid_shape(){
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     %i = constant 16 : index
index 1bed13c..2c4a13d 100644 (file)
@@ -195,6 +195,17 @@ module attributes {gpu.container_module} {
     return
   }
 
+  func @memset(%dst : memref<3x7xf32>, %value : f32) {
+    // CHECK-LABEL: func @memset
+    // CHECK: gpu.memset {{.*}}, {{.*}} : memref<3x7xf32>, f32
+    gpu.memset %dst, %value : memref<3x7xf32>, f32
+    // CHECK: %[[t0:.*]] = gpu.wait async
+    %0 = gpu.wait async
+    // CHECK: {{.*}} = gpu.memset async [%[[t0]]] {{.*}}, {{.*}} : memref<3x7xf32>, f32
+    %1 = gpu.memset async [%0] %dst, %value : memref<3x7xf32>, f32
+    return
+  }
+
   func @mmamatrix_valid_element_type(){
     // CHECK-LABEL: func @mmamatrix_valid_element_type
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>