Fix maskAndClamp in gpu.all_reduce.
author: Christian Sigg <csigg@google.com>
Fri, 13 Dec 2019 07:06:06 +0000 (23:06 -0800)
committer: A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 13 Dec 2019 23:28:58 +0000 (15:28 -0800)
The clamp value determines the returned predicate. Previously, the clamp value was fixed to 31 (`kWarpSize - 1`), so the predicate was always true. This is incorrect for partial warp reductions, but it went unnoticed because the returned values happened to be zero (although they could have been anything).

PiperOrigin-RevId: 285343160

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

index ccad2cd..e4bdd7c 100644 (file)
@@ -308,8 +308,6 @@ private:
                           ConversionPatternRewriter &rewriter) const {
     Value *warpSize = rewriter.create<LLVM::ConstantOp>(
         loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
-    Value *maskAndClamp = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
     Value *isPartialWarp = rewriter.create<LLVM::ICmpOp>(
         loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
     auto type = operand->getType().cast<LLVM::LLVMType>();
@@ -326,6 +324,9 @@ private:
               loc, int32Type,
               rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
               one);
+          // Clamp lane: `activeWidth - 1`
+          Value *maskAndClamp =
+              rewriter.create<LLVM::SubOp>(loc, int32Type, activeWidth, one);
           auto dialect = lowering.getDialect();
           auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
           auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy});
@@ -363,6 +364,8 @@ private:
           Value *value = operand;
           Value *activeMask = rewriter.create<LLVM::ConstantOp>(
               loc, int32Type, rewriter.getI32IntegerAttr(~0u));
+          Value *maskAndClamp = rewriter.create<LLVM::ConstantOp>(
+              loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
           for (int i = 1; i < kWarpSize; i <<= 1) {
             Value *offset = rewriter.create<LLVM::ConstantOp>(
                 loc, int32Type, rewriter.getI32IntegerAttr(i));