From: Christian Sigg Date: Fri, 13 Dec 2019 07:06:06 +0000 (-0800) Subject: Fix maskAndClamp in gpu.all_reduce. X-Git-Tag: llvmorg-11-init~1466^2~80 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8846557672d6a490b500b1c24e90a8effcb38901;p=platform%2Fupstream%2Fllvm.git Fix maskAndClamp in gpu.all_reduce. The clamp value determines the returned predicate. Previously, the clamp value was fixed to 31 and the predicate was therefore always true. This is incorrect for partial warp reductions, but went unnoticed because the returned values happened to be zero (but it could be anything). PiperOrigin-RevId: 285343160 --- diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index ccad2cd..e4bdd7c 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -308,8 +308,6 @@ private: ConversionPatternRewriter &rewriter) const { Value *warpSize = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize)); - Value *maskAndClamp = rewriter.create( - loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1)); Value *isPartialWarp = rewriter.create( loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize); auto type = operand->getType().cast(); @@ -326,6 +324,9 @@ private: loc, int32Type, rewriter.create(loc, int32Type, one, activeWidth), one); + // Clamp lane: `activeWidth - 1` + Value *maskAndClamp = + rewriter.create(loc, int32Type, activeWidth, one); auto dialect = lowering.getDialect(); auto predTy = LLVM::LLVMType::getInt1Ty(dialect); auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy}); @@ -363,6 +364,8 @@ private: Value *value = operand; Value *activeMask = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(~0u)); + Value *maskAndClamp = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1)); for (int i = 1; i < kWarpSize; i <<= 1) { Value *offset = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(i));