[mlir][nvvm] Add attribute to nvvm.cpAsyncOp to control l1 bypass
authorThomas Raoux <thomasraoux@google.com>
Mon, 9 May 2022 15:46:28 +0000 (15:46 +0000)
committerThomas Raoux <thomasraoux@google.com>
Mon, 9 May 2022 19:34:48 +0000 (19:34 +0000)
Add attribute to be able to generate the intrinsic version of async copy
generating a copy with l1 bypass. This correspond to
cp.async.cg.shared.global in ptx.

Differential Revision: https://reviews.llvm.org/D125241

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir

index f9d32f4..f19500e 100644 (file)
@@ -153,7 +153,8 @@ def NVVM_VoteBallotOp :
 def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
   Arguments<(ins LLVM_i8Ptr_shared:$dst,
                  LLVM_i8Ptr_global:$src,
-                 I32Attr:$size)> {
+                 I32Attr:$size,
+                 OptionalAttr<UnitAttr>:$bypass_l1)> {
   string llvmBuilder = [{
       llvm::Intrinsic::ID id;
       switch ($size) {
@@ -164,7 +165,10 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
           id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
           break;
         case 16:
-          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
+          if(static_cast<bool>($bypass_l1))
+            id = llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16;
+          else
+            id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
           break;
         default:
           llvm_unreachable("unsupported async copy size");
index 4f65730..6ccc8a0 100644 (file)
@@ -164,7 +164,8 @@ struct GPUAsyncCopyLowering
     int64_t sizeInBytes =
         (dstMemrefType.getElementTypeBitWidth() / 8) * numElements;
     rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, scrPtr,
-                                     rewriter.getI32IntegerAttr(sizeInBytes));
+                                     rewriter.getI32IntegerAttr(sizeInBytes),
+                                     /*bypassL1=*/UnitAttr());
 
     // Drop the result token.
     Value zero = rewriter.create<LLVM::ConstantOp>(
index 345d900..640e84a 100644 (file)
@@ -67,6 +67,8 @@ void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 LogicalResult CpAsyncOp::verify() {
   if (size() != 4 && size() != 8 && size() != 16)
     return emitError("expected byte size to be either 4, 8 or 16.");
+  if (bypass_l1() && size() != 16)
+    return emitError("bypass l1 is only support for 16 bytes copy.");
   return success();
 }
 
index 50b9f1b..876668d 100644 (file)
@@ -1261,6 +1261,14 @@ func.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
 
 // -----
 
+func.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
+  // expected-error @below {{bypass l1 is only support for 16 bytes copy.}}
+  nvvm.cp.async.shared.global %arg0, %arg1, 8 {bypass_l1}
+  return
+}
+
+// -----
+
 func.func @gep_struct_variable(%arg0: !llvm.ptr<struct<(i32)>>, %arg1: i32, %arg2: i32) {
   // expected-error @below {{op expected index 1 indexing a struct to be constant}}
   llvm.getelementptr %arg0[%arg1, %arg1] : (!llvm.ptr<struct<(i32)>>, i32, i32) -> !llvm.ptr<i32>
index dfe0443..728755d 100644 (file)
@@ -258,6 +258,8 @@ func.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
 llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
 // CHECK:  nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
   nvvm.cp.async.shared.global %arg0, %arg1, 16
+// CHECK:  nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1}
+  nvvm.cp.async.shared.global %arg0, %arg1, 16 {bypass_l1}
 // CHECK: nvvm.cp.async.commit.group
   nvvm.cp.async.commit.group
 // CHECK: nvvm.cp.async.wait.group 0
index fddfdda..f3bd013 100644 (file)
@@ -287,6 +287,8 @@ llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
   nvvm.cp.async.shared.global %arg0, %arg1, 8
 // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 16
+// CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
+  nvvm.cp.async.shared.global %arg0, %arg1, 16 {bypass_l1}
 // CHECK: call void @llvm.nvvm.cp.async.commit.group()
   nvvm.cp.async.commit.group
 // CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)