[mlir][NVGPU] Adding Support for cp_async_zfill via Inline Asm
authorManish Gupta <manigupta@google.com>
Fri, 2 Sep 2022 21:20:11 +0000 (21:20 +0000)
committerThomas Raoux <thomasraoux@google.com>
Fri, 2 Sep 2022 21:29:26 +0000 (21:29 +0000)
`cp_async_zfill` is currently not present in the nvvm backend, this patch adds `cp_async_zfill` support by adding inline asm when lowering from `nvgpu` to `nvvm`.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D132269

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

index d0dd5a6..3b92f54 100644 (file)
@@ -151,6 +151,14 @@ def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
     `bypassL1` attribute is hint to the backend and hardware that
     the copy should by pass the L1 cache, this may be dropped by the backend or
     hardware. 
+    `dstElements` attribute is the total number of elements written to 
+    destination (shared memory).
+    `srcElements` argument is the total number of elements read from 
+    source (global memory).
+    
+    srcElements` is an optional argument and when present it only reads 
+    srcElements number of elements from the source global memory and zero fills 
+    the rest of the elements in the destination shared memory.
 
     In order to do a copy and wait for the result we need the following
     combination:
@@ -183,10 +191,11 @@ def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
                        Variadic<Index>:$dstIndices,
                        Arg<AnyMemRef, "", [MemRead]>:$src,
                        Variadic<Index>:$srcIndices,
-                       IndexAttr:$numElements,
+                       IndexAttr:$dstElements,
+                       Optional<Index>:$srcElements,
                        OptionalAttr<UnitAttr>:$bypassL1);
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $numElements
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $dstElements (`,` $srcElements^)?
       attr-dict `:` type($src) `to` type($dst)
   }];
   let hasVerifier = 1;
index 9f1d19d..c4c49f2 100644 (file)
@@ -354,6 +354,35 @@ struct ConvertNVGPUToNVVMPass
   }
 };
 
+static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
+                                  Value dstBytes, Value srcElements,
+                                  mlir::MemRefType elementType,
+                                  ConversionPatternRewriter &rewriter) {
+  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+                                                  LLVM::AsmDialect::AD_ATT);
+  const char *asmStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
+  const char *asmConstraints = "r,l,n,r";
+
+  Value c3I32 = rewriter.create<LLVM::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
+  Value bitwidth = rewriter.create<LLVM::ConstantOp>(
+      loc, rewriter.getI32Type(),
+      rewriter.getI32IntegerAttr(elementType.getElementTypeBitWidth()));
+  Value srcElementsI32 =
+      rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcElements);
+  Value srcBytes = rewriter.create<LLVM::LShrOp>(
+      loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32), c3I32);
+
+  SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
+
+  rewriter.create<LLVM::InlineAsmOp>(
+      loc, LLVM::LLVMVoidType::get(rewriter.getContext()), /*operands=*/asmVals,
+      /*asm_string=*/asmStr,
+      /*constraints=*/asmConstraints, /*has_side_effects=*/true,
+      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
+      /*operand_attrs=*/ArrayAttr());
+}
+
 struct NVGPUAsyncCopyLowering
     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
   using ConvertOpToLLVMPattern<
@@ -383,15 +412,33 @@ struct NVGPUAsyncCopyLowering
         i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
     scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
                                                     scrPtr);
-    int64_t numElements = adaptor.getNumElements().getZExtValue();
+    int64_t dstElements = adaptor.getDstElements().getZExtValue();
     int64_t sizeInBytes =
-        (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
+        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
     // bypass L1 is only supported for byte sizes of 16, we drop the hint
     // otherwise.
     UnitAttr bypassL1 =
         sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
-    rewriter.create<NVVM::CpAsyncOp>(
-        loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);
+
+    // When the optional SrcElements argument is present, the source (global
+    // memory) of CpAsyncOp is read only for SrcElements number of elements. The
+    // rest of the DstElements in the destination (shared memory) are filled
+    // with zeros.
+    if (op.getSrcElements())
+      emitCpAsyncOpZfillAsm(loc, dstPtr, scrPtr,
+                            rewriter.create<LLVM::ConstantOp>(
+                                loc, rewriter.getI32Type(),
+                                rewriter.getI32IntegerAttr(sizeInBytes)),
+                            adaptor.getSrcElements(), srcMemrefType, rewriter);
+
+    // When the optional SrcElements argument is *not* present, the regular
+    // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
+    // memory) to fill DstElements number of elements in the destination (shared
+    // memory).
+    else
+      rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, scrPtr,
+                                       rewriter.getI32IntegerAttr(sizeInBytes),
+                                       bypassL1);
 
     // Drop the result token.
     Value zero = rewriter.create<LLVM::ConstantOp>(
index aa71a26..0a9f8d5 100644 (file)
@@ -297,3 +297,19 @@ func.func @async_cp_i4(
   return %0 : !nvgpu.device.async.token
 }
 
+// -----
+
+// CHECK-LABEL: @async_cp_zfill(
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
+func.func @async_cp_zfill(
+  %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
+
+  // CHECK-DAG: lvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>, i32, i32) -> !llvm.void
+  %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
+  // CHECK: nvvm.cp.async.commit.group
+  %1 = nvgpu.device_async_create_group %0
+  // CHECK: nvvm.cp.async.wait.group 1
+  nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
+
+  return
+}