[mlir][nvgpu] Add `mbarrier.arrive.expect_tx` and `mbarrier.try_wait.parity`
authorGuray Ozen <guray.ozen@gmail.com>
Thu, 20 Jul 2023 10:26:35 +0000 (12:26 +0200)
committerGuray Ozen <guray.ozen@gmail.com>
Thu, 20 Jul 2023 11:48:30 +0000 (13:48 +0200)
This work adds two Ops:
`mbarrier.arrive.expect_tx` performs expect_tx `mbarrier.barrier` returns `mbarrier.barrier.token`
`mbarrier.try_wait.parity` waits on `mbarrier.barrier` and `mbarrier.barrier.token`

`mbarrier.arrive.expect_tx` is one of the requirement to enable H100 TMA support.

Depends on D154074 D154076 D154059 D154060

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D154094

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

index ef17a6c..c3cafe6 100644 (file)
@@ -372,53 +372,59 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
 }
 
 def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
-  Results<(outs LLVM_Type:$res)>,
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
   Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
-  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
+  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
   let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"); }
+    std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
   }];
 }
 
 def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
-  Results<(outs LLVM_Type:$res)>,
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
   Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {    
-  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
+  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
   let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"); }
+    std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
   }];
 }
 
 def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
-  Results<(outs LLVM_Type:$res)>,
-  Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> {
-  let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {  
+  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
-      return std::string("{\n\t"
-              ".reg .pred P1; \n\t"
-              "mbarrier.try_wait.parity.b64 P1, [%1], %2; \n\t"
-              "selp.b32 %0, 1, 0, P1; \n\t"
-              "}"); 
+      return std::string(
+        "{\n\t"
+        ".reg .pred       P1; \n\t"
+        "LAB_WAIT: \n\t"
+        "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
+        "@P1 bra.uni DONE; \n\t"
+        "bra.uni     LAB_WAIT; \n\t"
+        "DONE: \n\t"
+        "}"
+      ); 
     }
   }];
 }
 
 def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
-  Results<(outs LLVM_Type:$res)>,
-  Arguments<(ins LLVM_i64ptr_shared:$addr, LLVM_Type:$token)> {  
-  let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {  
+  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
-      return std::string("{\n\t"
-              ".reg .pred P1; \n\t"
-              "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
-              "selp.b32 %0, 1, 0, P1; \n\t"
-              "}"); 
+      return std::string(
+        "{\n\t"
+        ".reg .pred       P1; \n\t"
+        "LAB_WAIT: \n\t"
+        "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
+        "@P1 bra.uni DONE; \n\t"
+        "bra.uni     LAB_WAIT; \n\t"
+        "DONE: \n\t"
+        "}"
+      ); 
     }
   }];
 }
index 9e783d0..eb0bdee 100644 (file)
@@ -469,4 +469,44 @@ def NVGPU_MBarrierArriveNoCompleteOp : NVGPU_Op<"mbarrier.arrive.nocomplete", []
   let assemblyFormat = "$barrier `,` $count attr-dict `:` type($barrier) `->` type($token)";
 }
 
+def NVGPU_MBarrierArriveExpectTxOp : NVGPU_Op<"mbarrier.arrive.expect_tx", []> {
+  let summary = "Performs expect_tx operation on the `nvgpu.mbarrier.arrive`";
+  let description = [{
+    A thread executing the Op performs an expect-tx operation on the mbarrier 
+    object at the location specified by the address operand $barrier. The 
+    expect-tx operation, with an $txcount argument, increases the tx-count of 
+    an mbarrier object by the value specified by $txcount. This makes the 
+    current phase of the mbarrier object to expect and track the completion of 
+    additional asynchronous transactions.
+    
+    The `$txCount` specifies the number of element to the expect-tx operation.
+
+    Example:
+    ```mlir
+      nvgpu.mbarrier.arrive.expect_tx %barrier, %ic0 : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+    ```
+  }];
+  let arguments = (ins NVGPU_MBarrier:$barrier,
+                       Index:$txcount);  
+  let assemblyFormat = "$barrier `,` $txcount  attr-dict `:` type($barrier)";
+}
+
+def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
+  let summary = "Waits for the `nvgpu.mbarrier` to complete its current phase.";
+  let description = [{
+    Checks whether the mbarrier object has completed the phase. It is is a 
+    potentially blocking instruction which tests for the completion of the 
+    phase. Suspended thread resumes execution when the specified phase completes 
+    OR before the phase completes following a system-dependent time limit. 
+
+    Example:
+    ```mlir
+      nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+    ```
+
+  }];
+  let arguments = (ins NVGPU_MBarrier:$barrier, Index:$phase, Index:$ticks);
+  let assemblyFormat = "$barrier `,` $phase `,` $ticks attr-dict `:` type($barrier)";  
+}
+
 #endif // NVGPU
index b8adef2..26be5c0 100644 (file)
@@ -25,6 +25,17 @@ namespace mlir {
 
 using namespace mlir;
 
+/// GPU has 32 bit registers, this function truncates values when larger width
+/// is not needed.
+static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
+                        Value value) {
+  Type type = value.getType();
+  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
+  if (type.getIntOrFloatBitWidth() <= 32)
+    return value;
+  return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
+}
+
 /// Returns the type for the intrinsic given the vectorResultType of the
 /// `gpu.mma.sync` operation.
 static Type inferIntrinsicResultType(Type vectorResultType) {
@@ -850,6 +861,55 @@ struct NVGPUMBarrierTestWaitLowering
   }
 };
 
+struct NVGPUMBarrierArriveExpectTxLowering
+    : public ConvertOpToLLVMPattern<nvgpu::MBarrierArriveExpectTxOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::MBarrierArriveExpectTxOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
+                                   op.getBarrier(), adaptor.getBarrier());
+    Value txcount = truncToI32(rewriter, op->getLoc(), adaptor.getTxcount());
+
+    if (isMbarrierShared(op.getBarrier().getType())) {
+      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
+          op, barrier, txcount);
+      return success();
+    }
+
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(op, barrier,
+                                                                txcount);
+    return success();
+  }
+};
+
+struct NVGPUMBarrierTryWaitParityLowering
+    : public ConvertOpToLLVMPattern<nvgpu::MBarrierTryWaitParityOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::MBarrierTryWaitParityOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
+                                   op.getBarrier(), adaptor.getBarrier());
+    Value ticks = truncToI32(rewriter, op->getLoc(), adaptor.getTicks());
+    Value phase = truncToI32(rewriter, op->getLoc(), adaptor.getPhase());
+
+    if (isMbarrierShared(op.getBarrier().getType())) {
+      rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
+          op, barrier, phase, ticks);
+      return success();
+    }
+
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
+                                                               phase, ticks);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -859,7 +919,9 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
       NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
       NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
-      NVGPUMBarrierTestWaitLowering,         // nvgpu.try_wait_parity
+      NVGPUMBarrierTestWaitLowering,         // nvgpu.mbarrier.test_wait_parity
+      NVGPUMBarrierTryWaitParityLowering,    // nvgpu.mbarrier.try_wait_parity
+      NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
index 7a7f65f..c7a0c7f 100644 (file)
@@ -558,3 +558,49 @@ func.func @mbarrier_nocomplete() {
 
   func.return 
 }
+
+
+// -----
+!barrierType = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+!tokenType = !nvgpu.mbarrier.token
+
+// CHECK-LABEL: func @mbarrier_txcount
+func.func @mbarrier_txcount() {
+      %num_threads = arith.constant 128 : index
+
+    // CHECK: %[[barMemref:.+]] = memref.get_global @__mbarrier : memref<1xi64, 3>
+    %barrier = nvgpu.mbarrier.create -> !barrierType
+
+    // CHECK: %[[barStr:.+]] =  builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: %[[barPtr:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: nvvm.mbarrier.init.shared %[[barPtr]]
+    nvgpu.mbarrier.init %barrier, %num_threads : !barrierType
+    
+    %c0 = arith.constant 0 : index  
+    %tidxreg = nvvm.read.ptx.sreg.tid.x : i32
+    %tidx = arith.index_cast %tidxreg : i32 to index
+    %cnd = arith.cmpi eq, %tidx, %c0 : index  
+
+    scf.if %cnd {
+      %txcount = arith.constant 256 : index
+      // CHECK: %[[barPtr2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+      // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+      nvgpu.mbarrier.arrive.expect_tx %barrier, %txcount : !barrierType
+      scf.yield 
+    } else {
+      %txcount = arith.constant 0 : index
+      // CHECK: %[[barPtr2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+      // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+      nvgpu.mbarrier.arrive.expect_tx %barrier, %txcount : !barrierType
+      scf.yield 
+    }
+      
+
+    %phase = arith.constant 0 : index
+    %ticks = arith.constant 10000000 : index
+    // CHECK: %[[barPtr3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+    nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !barrierType
+
+    func.return 
+}
\ No newline at end of file
index 5d3218e..0d93072 100644 (file)
@@ -1,31 +1,31 @@
 // RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file %s | FileCheck %s
 
 // CHECK-LABEL : @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) -> i64 {
-  //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 $0, [$1], $2;", "=l,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i64            
-  %res = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
-  llvm.return %res : i64
+llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) {
+  //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r" 
+  nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+  llvm.return
 }
 
 // CHECK-LABEL : @init_mbarrier_arrive_expect_tx_generic
-llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32)-> i64 {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 $0, [$1], $2;", "=l,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i64
-  %res = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64
-  llvm.return %res : i64
+llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r" 
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
+  llvm.return
 }
 
 // CHECK-LABEL : @init_mbarrier_try_wait.parity.shared
-llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %token : i32) -> i32 {
-  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32
-  %res = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32
-  llvm.return %res : i32
+llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred       P1; \0A\09LAB_WAIT: \0A\09mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2; \0A\09@P1 bra.uni DONE; \0A\09bra.uni     LAB_WAIT; \0A\09DONE: \0A\09}", "r,r,r"
+   nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
+  llvm.return
 }
 
 // CHECK-LABEL : @init_mbarrier_try_wait.parity
-llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %token : i32) -> i32{
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32
-  %res = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32
-  llvm.return %res : i32
+llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred       P1; \0A\09LAB_WAIT: \0A\09mbarrier.try_wait.parity.b64 P1, [$0], $1, $2; \0A\09@P1 bra.uni DONE; \0A\09bra.uni     LAB_WAIT; \0A\09DONE: \0A\09}", "r,r,r"
+  nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
+  llvm.return
 }
 
 // CHECK-LABEL : @async_cp