From f68ac464d818629e0fe10c23b44ac782d64a12d2 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 12 Dec 2019 01:27:27 -0800 Subject: [PATCH] Switch from shfl.bfly to shfl.down. Both work for the current use case, but the latter allows implementing prefix sums and is a little easier to understand for partial warps. PiperOrigin-RevId: 285145287 --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 8 ++++---- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 4 ++-- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 4 ++-- mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 10 +++++----- mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 4 ++-- mlir/test/Dialect/LLVMIR/invalid.mlir | 6 +++--- mlir/test/Dialect/LLVMIR/nvvm.mlir | 16 ++++++++-------- mlir/test/Target/nvvmir.mlir | 16 ++++++++-------- 8 files changed, 34 insertions(+), 34 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index c6e89af..9a0e43e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -90,8 +90,8 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> { let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; } -def NVVM_ShflBflyOp : - NVVM_Op<"shfl.sync.bfly">, +def NVVM_ShflDownOp : + NVVM_Op<"shfl.sync.down">, Results<(outs LLVM_Type:$res)>, Arguments<(ins LLVM_Type:$dst, LLVM_Type:$val, @@ -99,12 +99,12 @@ def NVVM_ShflBflyOp : LLVM_Type:$mask_and_clamp, OptionalAttr:$return_value_and_is_valid)> { string llvmBuilder = [{ - auto intId = getShflBflyIntrinsicId( + auto intId = getShflDownIntrinsicId( $_resultType, static_cast($return_value_and_is_valid)); $res = createIntrinsicCall(builder, intId, {$dst, $val, $offset, $mask_and_clamp}); }]; - let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }]; + let parser = [{ return parseNVVMShflSyncDownOp(parser, result); }]; let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; let verifier = [{ if (!getAttrOfType("return_value_and_is_valid")) diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index ccad2cd..a7474cf 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -337,7 +337,7 @@ private: for (int i = 1; i < kWarpSize; i <<= 1) { Value *offset = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(i)); - Value *shfl = rewriter.create( + Value *shfl = rewriter.create( loc, shflTy, activeMask, value, offset, maskAndClamp, returnValueAndIsValidAttr); Value *isActiveSrcLane = rewriter.create( @@ -366,7 +366,7 @@ private: for (int i = 1; i < kWarpSize; i <<= 1) { Value *offset = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(i)); - Value *shflValue = rewriter.create( + Value *shflValue = rewriter.create( loc, type, activeMask, value, offset, maskAndClamp, /*return_value_and_is_valid=*/UnitAttr()); value = accumFactory(loc, value, shflValue, rewriter); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 0b10391..32bb3a6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -70,9 +70,9 @@ static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) { } // ::= -// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask` +// `llvm.nvvm.shfl.sync.down %dst, %val, %offset, %clamp_and_mask` // ({return_value_and_is_valid})? : result_type -static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser, +static ParseResult parseNVVMShflSyncDownOp(OpAsmParser &parser, OperationState &result) { SmallVector ops; Type resultType; diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 606e91b..cdaa384 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -44,15 +44,15 @@ static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder, return builder.CreateCall(fn, args); } -static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType, +static llvm::Intrinsic::ID getShflDownIntrinsicId(llvm::Type *resultType, bool withPredicate) { if (withPredicate) { resultType = cast(resultType)->getElementType(0); - return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p - : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p; + return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p + : llvm::Intrinsic::nvvm_shfl_sync_down_i32p; } - return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 - : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; + return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32 + : llvm::Intrinsic::nvvm_shfl_sync_down_i32; } class ModuleTranslation : public LLVM::ModuleTranslation { diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index 30bba48..a68c775 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -44,7 +44,7 @@ module attributes {gpu.kernel_module} { attributes { gpu.kernel } { %arg0 = constant 1.0 : f32 // TODO(csigg): Check full IR expansion once lowering has settled. - // CHECK: nvvm.shfl.sync.bfly + // CHECK: nvvm.shfl.sync.down // CHECK: nvvm.barrier0 // CHECK: llvm.fadd %result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32) @@ -61,7 +61,7 @@ module attributes {gpu.kernel_module} { attributes { gpu.kernel } { %arg0 = constant 1 : i32 // TODO(csigg): Check full IR expansion once lowering has settled. - // CHECK: nvvm.shfl.sync.bfly + // CHECK: nvvm.shfl.sync.down // CHECK: nvvm.barrier0 %result = "gpu.all_reduce"(%arg0) ({ ^bb(%lhs : i32, %rhs : i32): diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index c7d4ff4..d6df074 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -268,7 +268,7 @@ func @null_non_llvm_type() { // CHECK-LABEL: @nvvm_invalid_shfl_pred_1 func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) { // expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}} - %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32 + %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32 } // ----- @@ -276,7 +276,7 @@ func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !ll // CHECK-LABEL: @nvvm_invalid_shfl_pred_2 func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) { // expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}} - %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }"> + %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }"> } // ----- @@ -284,7 +284,7 @@ func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !ll // CHECK-LABEL: @nvvm_invalid_shfl_pred_3 func @nvvm_invalid_shfl_pred_3(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) { // expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}} - %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }"> + %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }"> } // ----- diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index 3858ace..50a7f7d 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -37,20 +37,20 @@ func @llvm.nvvm.barrier0() { func @nvvm_shfl( %arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm.i32 { - // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32 - %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : !llvm.i32 - // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float - %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : !llvm.float + // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32 + %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 : !llvm.i32 + // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float + %1 = nvvm.shfl.sync.down %arg0, %arg4, %arg1, %arg2 : !llvm.float llvm.return %0 : !llvm.i32 } func @nvvm_shfl_pred( %arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm<"{ i32, i1 }"> { - // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }"> - %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }"> - // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }"> - %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }"> + // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }"> + %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }"> + // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }"> + %1 = nvvm.shfl.sync.down %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }"> llvm.return %0 : !llvm<"{ i32, i1 }"> } diff --git a/mlir/test/Target/nvvmir.mlir b/mlir/test/Target/nvvmir.mlir index 2e63ecd..e97603a 100644 --- a/mlir/test/Target/nvvmir.mlir +++ b/mlir/test/Target/nvvmir.mlir @@ -41,20 +41,20 @@ llvm.func @llvm.nvvm.barrier0() { llvm.func @nvvm_shfl( %0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32, %3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 { - // CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : !llvm.i32 - // CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : !llvm.float + // CHECK: call i32 @llvm.nvvm.shfl.sync.down.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %6 = nvvm.shfl.sync.down %0, %3, %1, %2 : !llvm.i32 + // CHECK: call float @llvm.nvvm.shfl.sync.down.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %7 = nvvm.shfl.sync.down %0, %4, %1, %2 : !llvm.float llvm.return %6 : !llvm.i32 } llvm.func @nvvm_shfl_pred( %0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32, %3 : !llvm.i32, %4 : !llvm.float) -> !llvm<"{ i32, i1 }"> { - // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }"> - // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }"> + // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.down.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %6 = nvvm.shfl.sync.down %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }"> + // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.down.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %7 = nvvm.shfl.sync.down %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }"> llvm.return %6 : !llvm<"{ i32, i1 }"> } -- 2.7.4