`bypassL1` attribute is a hint to the backend and hardware that
the copy should bypass the L1 cache; this may be dropped by the backend or
hardware.
+ `dstElements` attribute is the total number of elements written to
+ destination (shared memory).
+ `srcElements` argument is the total number of elements read from
+ source (global memory).
+
+ `srcElements` is an optional argument and when present it only reads
+ srcElements number of elements from the source global memory and zero fills
+ the rest of the elements in the destination shared memory.
In order to do a copy and wait for the result we need the following
combination:
Variadic<Index>:$dstIndices,
Arg<AnyMemRef, "", [MemRead]>:$src,
Variadic<Index>:$srcIndices,
- IndexAttr:$numElements,
+ IndexAttr:$dstElements,
+ Optional<Index>:$srcElements,
OptionalAttr<UnitAttr>:$bypassL1);
let assemblyFormat = [{
- $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $numElements
+ $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $dstElements (`,` $srcElements^)?
attr-dict `:` type($src) `to` type($dst)
}];
let hasVerifier = 1;
}
};
+// Emits an inline-asm `cp.async.cg.shared.global` copy that reads
+// `srcElements` elements from global memory (`srcPtr`) into shared memory
+// (`dstPtr`) and zero-fills the remainder of the `dstBytes`-sized
+// destination.
+//
+// Asm operand order follows `asmVals`/`asmConstraints`:
+//   $0 = dstPtr ("r"), $1 = srcPtr ("l"),
+//   $2 = dstBytes (immediate "n"), $3 = srcBytes ("r").
+//
+// NOTE(review): `elementType` is a MemRefType, not an element type; only its
+// element bit-width is consumed here — consider renaming.
+static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
+ Value dstBytes, Value srcElements,
+ mlir::MemRefType elementType,
+ ConversionPatternRewriter &rewriter) {
+ auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+ LLVM::AsmDialect::AD_ATT);
+ const char *asmStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
+ const char *asmConstraints = "r,l,n,r";
+
+ // Convert the element count to a byte count:
+ // srcBytes = (elementBitWidth * srcElements) >> 3.
+ Value c3I32 = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
+ Value bitwidth = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(elementType.getElementTypeBitWidth()));
+ // Truncate the (index-typed) element count to i32 for the asm operand.
+ Value srcElementsI32 =
+ rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcElements);
+ Value srcBytes = rewriter.create<LLVM::LShrOp>(
+ loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32), c3I32);
+
+ SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
+
+ // The copy mutates shared memory, so the asm must be marked as having
+ // side effects to keep it from being dropped.
+ rewriter.create<LLVM::InlineAsmOp>(
+ loc, LLVM::LLVMVoidType::get(rewriter.getContext()), /*operands=*/asmVals,
+ /*asm_string=*/asmStr,
+ /*constraints=*/asmConstraints, /*has_side_effects=*/true,
+ /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
+ /*operand_attrs=*/ArrayAttr());
+}
+
struct NVGPUAsyncCopyLowering
: public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
using ConvertOpToLLVMPattern<
i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
scrPtr);
- int64_t numElements = adaptor.getNumElements().getZExtValue();
+ int64_t dstElements = adaptor.getDstElements().getZExtValue();
int64_t sizeInBytes =
- (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
+ (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
// bypass L1 is only supported for byte sizes of 16, we drop the hint
// otherwise.
UnitAttr bypassL1 =
sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
- rewriter.create<NVVM::CpAsyncOp>(
- loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);
+
+ // When the optional SrcElements argument is present, the source (global
+ // memory) of CpAsyncOp is read only for SrcElements number of elements. The
+ // rest of the DstElements in the destination (shared memory) are filled
+ // with zeros.
+ if (op.getSrcElements())
+ emitCpAsyncOpZfillAsm(loc, dstPtr, scrPtr,
+ rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(sizeInBytes)),
+ adaptor.getSrcElements(), srcMemrefType, rewriter);
+
+ // When the optional SrcElements argument is *not* present, the regular
+ // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
+ // memory) to fill DstElements number of elements in the destination (shared
+ // memory).
+ else
+ rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, scrPtr,
+ rewriter.getI32IntegerAttr(sizeInBytes),
+ bypassL1);
// Drop the result token.
Value zero = rewriter.create<LLVM::ConstantOp>(
return %0 : !nvgpu.device.async.token
}
+// -----
+
+// CHECK-LABEL: @async_cp_zfill(
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
+func.func @async_cp_zfill(
+  %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
+
+  // With the optional srcElements operand present, the copy lowers to the
+  // zero-fill inline-asm form instead of nvvm.cp.async.
+  // Fixed typo: the pattern previously read "lvm.inline_asm" and only
+  // matched by substring accident.
+  // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>, i32, i32) -> !llvm.void
+  %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
+  // CHECK: nvvm.cp.async.commit.group
+  %1 = nvgpu.device_async_create_group %0
+  // CHECK: nvvm.cp.async.wait.group 1
+  nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
+
+  return
+}