[mlir][nvvm] Implement `mbarrier.init`
authorGuray Ozen <guray.ozen@gmail.com>
Fri, 16 Jun 2023 08:03:30 +0000 (10:03 +0200)
committerGuray Ozen <guray.ozen@gmail.com>
Fri, 16 Jun 2023 11:35:14 +0000 (13:35 +0200)
NV GPUs provides split arrive/wait barriers that one can syncronize a subgroup of threads in CTA. It is particularly important for Hopper GPUs and allows tracking engines like TMA. See for more details:
https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier

This initial implementation sets the foundation for future enhancements and additions.

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D151334

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/test/Dialect/LLVMIR/nvvm.mlir

index 5dcd5f9..118e784 100644 (file)
@@ -19,6 +19,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 
 def LLVM_i8Ptr_global : LLVM_IntPtrBase<8, 1>;
 def LLVM_i8Ptr_shared : LLVM_IntPtrBase<8, 3>;
+def LLVM_i64ptr_any : LLVM_IntPtrBase<64>;
+def LLVM_i64ptr_shared : LLVM_IntPtrBase<64, 3>;
 
 //===----------------------------------------------------------------------===//
 // NVVM dialect definitions
@@ -174,6 +176,28 @@ def NVVM_ReduxOp :
 }
 
 //===----------------------------------------------------------------------===//
+// NVVM Split arrive/wait barrier
+//===----------------------------------------------------------------------===//
+
+/// mbarrier.init instruction with generic pointer type
+def NVVM_MBarrierInitOp : NVVM_Op<"mbarrier.init">,
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
+  string llvmBuilder = [{
+      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
+  }];
+  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
+}
+
+/// mbarrier.init instruction with shared pointer type
+def NVVM_MBarrierInitSharedOp : NVVM_Op<"mbarrier.init.shared">,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
+  string llvmBuilder = [{
+      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
+  }];
+  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
+}
+
+//===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//
 
index c7c83d2..d08d02a 100644 (file)
@@ -337,3 +337,19 @@ llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
 
 // expected-error@below {{attribute attached to unexpected op}}
 func.func private @expected_llvm_func() attributes { nvvm.kernel }
+
+// -----
+llvm.func private @mbarrier_init_generic(%barrier: !llvm.ptr) {
+  %count = nvvm.read.ptx.sreg.ntid.x : i32
+  // CHECK:   nvvm.mbarrier.init %{{.*}}, %{{.*}} : !llvm.ptr, i32
+  nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32
+  llvm.return
+}
+
+
+llvm.func private @mbarrier_init_shared(%barrier: !llvm.ptr<3>) {
+  %count = nvvm.read.ptx.sreg.ntid.x : i32
+  // CHECK:   nvvm.mbarrier.init.shared %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32
+  nvvm.mbarrier.init.shared %barrier, %count : !llvm.ptr<3>, i32
+  llvm.return
+}