From bf62748342438d7136ca78ef3875b31442b1ccd3 Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Wed, 19 Jul 2023 10:29:41 +0200 Subject: [PATCH] [mlir][nvvm] Introduce Syncronization Ops for WGMMA This work introduces : `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned` and `wgmma.wait.group.sync.aligned` Ops. They are used to syncronize warpgroup level matrix multiply-accumulate instructions, as known as WGMMA. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D155676 --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 47 +++++++++++++++++++++++ mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 20 ++++++++++ mlir/test/Dialect/LLVMIR/nvvm.mlir | 22 +++++++++++ 3 files changed, 89 insertions(+) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index efdd3d6..ef17a6c 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1419,4 +1419,51 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// NVVM Wgmma Ops +//===----------------------------------------------------------------------===// + +def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", + [DeclareOpInterfaceMethods]> { + let arguments = (ins); + let description = [{ + Enforce an ordering of register accesses between warpgroup level matrix + multiplication and other operations. + See for more information: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence + }]; + let assemblyFormat = "attr-dict"; + let extraClassDefinition = [{ + std::string $cppClass::getPtx() { return std::string("wgmma.fence.sync.aligned;"); } + }]; +} + +def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", + [DeclareOpInterfaceMethods]>, + Arguments<(ins )> { + let assemblyFormat = "attr-dict"; + let description = [{ + Commits all prior uncommitted warpgroup level matrix multiplication operations. + See for more information: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group + }]; + let extraClassDefinition = [{ + std::string $cppClass::getPtx() { return std::string("wgmma.commit_group.sync.aligned;"); } + }]; +} + +def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", + [DeclareOpInterfaceMethods]>{ + let arguments = (ins I32Attr:$group); + let assemblyFormat = "attr-dict $group"; + let description = [{ + Signal the completion of a preceding warpgroup operation. + See for more information: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group + }]; + let extraClassDefinition = [{ + std::string $cppClass::getPtx() { return std::string("wgmma.wait_group.sync.aligned %0;"); } + }]; +} + #endif // NVVMIR_OPS diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 4201c7b..5d3218e 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -80,3 +80,23 @@ func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32 return } + + +// CHECK-LABEL : @wgmma_execute +func.func @wgmma_execute() { + nvvm.wgmma.fence.aligned + nvvm.wgmma.commit.group.sync.aligned + nvvm.wgmma.wait.group.sync.aligned 0 + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;", "" + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;", "" + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned %0;", "n" %{{.*}} : (i32) + + + nvvm.wgmma.fence.aligned + nvvm.wgmma.commit.group.sync.aligned + nvvm.wgmma.wait.group.sync.aligned 1 + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;", "" + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;", "" + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned %0;", "n" %{{.*}} : (i32) + return +} diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index bbc7676..b26f3b0 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -407,3 +407,25 @@ llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i6 %isComplete = nvvm.mbarrier.test.wait.shared %barrier, %token : !llvm.ptr<3>, i64 -> i1 llvm.return } + +// CHECK-LABEL : @wgmma_fence_aligned +func.func @wgmma_fence_aligned() { + // CHECK : nvvm.wgmma.fence.aligned + nvvm.wgmma.fence.aligned + return +} + +// CHECK-LABEL : @wgmma_commit_group_sync_aligned +func.func @wgmma_commit_group_sync_aligned() { + // CHECK : nvvm.wgmma.commit.group.sync.aligned + nvvm.wgmma.commit.group.sync.aligned + return +} + + +// CHECK-LABEL : @wgmma_commit_group_sync_aligned +func.func @wgmma_wait_group_sync_aligned() { + // CHECK : nvvm.wgmma.wait.group.sync.aligned + nvvm.wgmma.wait.group.sync.aligned 0 + return +} -- 2.7.4