From 36663626ee336905745cb1c259b3b65c9ff656bf Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Thu, 19 Jan 2023 10:09:02 +0100 Subject: [PATCH] [mlir][nvvm] Introduce redux op The PTX model has `redux.sync`, which performs a reduction operation on the data from each predicated active thread in the thread group. It is only available on sm80+. This revision adds redux as an op to the nvvm dialect. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D142088 --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 39 ++++++++++++++++++++++ .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 26 +++++++++++++++ mlir/test/Dialect/LLVMIR/nvvm.mlir | 23 +++++++++++++ 3 files changed, 88 insertions(+) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 8a16d64..289064c 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -136,6 +136,45 @@ def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [Pure], 1> { } //===----------------------------------------------------------------------===// +// NVVM redux op definitions +//===----------------------------------------------------------------------===// + +def ReduxKindNone : I32EnumAttrCase<"NONE", 0, "none">; +def ReduxKindAdd : I32EnumAttrCase<"ADD", 1, "add">; +def ReduxKindAnd : I32EnumAttrCase<"AND", 2, "and">; +def ReduxKindMax : I32EnumAttrCase<"MAX", 3, "max">; +def ReduxKindMin : I32EnumAttrCase<"MIN", 4, "min">; +def ReduxKindOr : I32EnumAttrCase<"OR", 5, "or">; +def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">; +def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">; +def ReduxKindXor : I32EnumAttrCase<"XOR", 8, "xor">; + +/// Enum attribute of the different kinds. 
+def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind", + [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr, + ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} + +def ReduxKindAttr : EnumAttr; + +def NVVM_ReduxOp : + NVVM_Op<"redux.sync">, + Results<(outs LLVM_Type:$res)>, + Arguments<(ins LLVM_Type:$val, + ReduxKindAttr:$kind, + I32:$mask_and_clamp)> { + string llvmBuilder = [{ + auto intId = getReduxIntrinsicId($_resultType, $kind); + $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp}); + }]; + let assemblyFormat = [{ + $kind $val `,` $mask_and_clamp attr-dict `:` type($val) `->` type($res) + }]; +} + +//===----------------------------------------------------------------------===// // NVVM synchronization op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index feaf5ca..d7f1bb6 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -25,6 +25,32 @@ using namespace mlir; using namespace mlir::LLVM; using mlir::LLVM::detail::createIntrinsicCall; +static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, + NVVM::ReduxKind kind) { + if (!resultType->isIntegerTy(32)) + llvm_unreachable("unsupported data type for redux"); + + switch (kind) { + case NVVM::ReduxKind::ADD: + return llvm::Intrinsic::nvvm_redux_sync_add; + case NVVM::ReduxKind::UMAX: + return llvm::Intrinsic::nvvm_redux_sync_umax; + case NVVM::ReduxKind::UMIN: + return llvm::Intrinsic::nvvm_redux_sync_umin; + case NVVM::ReduxKind::AND: + return llvm::Intrinsic::nvvm_redux_sync_and; + case NVVM::ReduxKind::OR: + return llvm::Intrinsic::nvvm_redux_sync_or; + case NVVM::ReduxKind::XOR: + return 
llvm::Intrinsic::nvvm_redux_sync_xor; + case NVVM::ReduxKind::MAX: + return llvm::Intrinsic::nvvm_redux_sync_max; + case NVVM::ReduxKind::MIN: + return llvm::Intrinsic::nvvm_redux_sync_min; + } + llvm_unreachable("unknown redux kind"); +} + static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate) { diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index 150f308..2e3b20b 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -310,6 +310,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr) { %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32)> llvm.return } + +// CHECK-LABEL: llvm.func @redux_sync +llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 { + // CHECK: nvvm.redux.sync add %{{.*}} + %r1 = nvvm.redux.sync add %value, %offset : i32 -> i32 + // CHECK: nvvm.redux.sync max %{{.*}} + %r2 = nvvm.redux.sync max %value, %offset : i32 -> i32 + // CHECK: nvvm.redux.sync min %{{.*}} + %r3 = nvvm.redux.sync min %value, %offset : i32 -> i32 + // CHECK: nvvm.redux.sync umax %{{.*}} + %r5 = nvvm.redux.sync umax %value, %offset : i32 -> i32 + // CHECK: nvvm.redux.sync umin %{{.*}} + %r6 = nvvm.redux.sync umin %value, %offset : i32 -> i32 + // CHECK: nvvm.redux.sync and %{{.*}} + %r7 = nvvm.redux.sync and %value, %offset : i32 -> i32 + // CHECK: nvvm.redux.sync or %{{.*}} + %r8 = nvvm.redux.sync or %value, %offset : i32 -> i32 + // CHECK: nvvm.redux.sync xor %{{.*}} + %r9 = nvvm.redux.sync xor %value, %offset : i32 -> i32 + llvm.return %r1 : i32 +} + + // ----- // expected-error@below {{attribute attached to unexpected op}} -- 2.7.4