Change all_reduce lowering to support 2D and 3D blocks.
authorChristian Sigg <csigg@google.com>
Tue, 1 Oct 2019 09:50:47 +0000 (02:50 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 1 Oct 2019 09:51:15 +0000 (02:51 -0700)
Perform second reduce only with first warp. This requires an additional __sync_threads(), but doesn't need special handling when the last warp is small. This simplifies support for block sizes that are not multiple of 32.

Supporting partial warp reduce will be done in a separate CL.

PiperOrigin-RevId: 272168917

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/test/mlir-cuda-runner/all-reduce.mlir

index a36a5b3..8145ec4 100644 (file)
@@ -123,30 +123,46 @@ private:
   //
   // First reduce the elements within a warp. The first thread of each warp
   // writes the intermediate result to shared memory. After synchronizing the
-  // block, each warp reduces all values from shared memory.
+  // block, the first warp reduces the values from shared memory. The result
+  // is broadcasted to all threads through shared memory.
   //
-  //     %warp_reduce = ... (see createWarpReduce)
-  //     %buffer = llvm.mlir.addressof @reduce_buffer : !llvm<"[32 x float]*">
+  //     %warp_reduce = `createWarpReduce(%operand)`
+  //     %shared_mem_ptr = llvm.mlir.addressof @reduce_buffer
   //     %zero = llvm.mlir.constant(0 : i32) : !llvm.i32
   //     %lane_id = nvvm.read.ptx.sreg.laneid  : !llvm.i32
-  //     %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i32
-  //     llvm.cond_br %is_first_lane, ^then, ^continue
-  //   ^then:
-  //     %warp_id = ... (see getWarpId)
-  //     %store_dst = llvm.getelementptr %buffer[%zero, %warp_id]
-  //     llvm.store %store_dst, %warp_reduce : !llvm.float
-  //     llvm.br ^continue
-  //   ^continue:
+  //     %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i1
+  //     %thread_idx = `getLinearThreadIndex()` : !llvm.i32
+  //     llvm.cond_br %is_first_lane, ^then1, ^continue1
+  //   ^then1:
+  //     %warp_id = `getWarpId()`
+  //     %store_dst = llvm.getelementptr %shared_mem_ptr[%zero, %warp_id]
+  //     llvm.store %store_dst, %warp_reduce
+  //     llvm.br ^continue1
+  //   ^continue1:
   //     nvvm.barrier0
-  //     %load_src = llvm.getelementptr %buffer[%zero, %lane_id]
-  //     %value = llvm.load %load_src : !llvm.float
-  //     %result = ... (see createWarpReduce)
+  //     %num_warps = `getNumWarps()` : !llvm.i32
+  //     %is_valid_warp = llvm.icmp "slt" %thread_idx, %num_warps
+  //     %result_ptr = llvm.getelementptr %shared_mem_ptr[%zero, %zero]
+  //     llvm.cond_br %is_first_lane, ^then2, ^continue2
+  //   ^then2:
+  //     %load_src = llvm.getelementptr %shared_mem_ptr[%zero, %thread_idx]
+  //     %value = llvm.load %load_src
+  //     %result = `createWarpReduce(%value)`
+  //     llvm.store %result_ptr, %result
+  //     llvm.br ^continue2
+  //   ^continue2:
+  //     nvvm.barrier0
+  //     %result = llvm.load %result_ptr
+  //     return %result
+  //
   Value *createBlockReduce(Location loc, Value *operand,
                            ConversionPatternRewriter &rewriter) const {
     auto type = operand->getType().cast<LLVM::LLVMType>();
 
+    // Reduce elements within each warp to produce the intermediate results.
     Value *warpReduce = createWarpReduce(loc, operand, rewriter);
 
+    // Create shared memory array to store the warp reduction.
     auto module = warpReduce->getDefiningOp()->getParentOfType<ModuleOp>();
     assert(module && "op must belong to a module");
     Value *sharedMemPtr =
@@ -157,7 +173,59 @@ private:
     Value *laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
     Value *isFirstLane = rewriter.create<LLVM::ICmpOp>(
         loc, LLVM::ICmpPredicate::eq, laneId, zero);
+    Value *threadIdx = getLinearThreadIndex(loc, rewriter);
+
+    // Write the intermediate results to shared memory, using the first lane of
+    // each warp.
+    createPredicatedBlock(
+        loc, isFirstLane,
+        [&] {
+          Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
+          Value *storeDst = rewriter.create<LLVM::GEPOp>(
+              loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
+          rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
+        },
+        rewriter);
+
+    rewriter.create<NVVM::Barrier0Op>(loc);
+    Value *numWarps = getNumWarps(loc, rewriter);
+    Value *isValidWarp = rewriter.create<LLVM::ICmpOp>(
+        loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
+    Value *resultPtr = rewriter.create<LLVM::GEPOp>(
+        loc, type, sharedMemPtr, ArrayRef<Value *>({zero, zero}));
+
+    // Use the first numWarps threads to reduce the intermediate results from
+    // shared memory. The final result is written to shared memory again.
+    createPredicatedBlock(
+        loc, isValidWarp,
+        [&] {
+          Value *loadSrc = rewriter.create<LLVM::GEPOp>(
+              loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
+          Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
+          Value *result = createWarpReduce(loc, value, rewriter);
+          rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
+        },
+        rewriter);
+
+    rewriter.create<NVVM::Barrier0Op>(loc);
+    Value *result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
 
+    return result;
+  }
+
+  // Creates an if-block skeleton to perform conditional execution of the
+  // instructions generated by predicatedOpsFactory.
+  //
+  //     llvm.cond_br %condition, ^then, ^continue
+  //   ^then:
+  //     ... code created in `predicatedOpsFactory()`
+  //     llvm.br ^continue
+  //   ^continue:
+  //
+  template <typename Func>
+  void createPredicatedBlock(Location loc, Value *condition,
+                             Func &&predicatedOpsFactory,
+                             ConversionPatternRewriter &rewriter) const {
     Block *currentBlock = rewriter.getInsertionBlock();
     auto currentPoint = rewriter.getInsertionPoint();
 
@@ -166,25 +234,15 @@ private:
 
     rewriter.setInsertionPointToEnd(currentBlock);
     rewriter.create<LLVM::CondBrOp>(
-        loc, llvm::makeArrayRef(isFirstLane),
+        loc, llvm::makeArrayRef(condition),
         ArrayRef<Block *>{thenBlock, continueBlock});
 
     rewriter.setInsertionPointToEnd(thenBlock);
-    Value *warpId = getWarpId(loc, rewriter);
-    Value *storeDst = rewriter.create<LLVM::GEPOp>(
-        loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
-    rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
+    predicatedOpsFactory();
     rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>(),
                                 llvm::makeArrayRef(continueBlock));
 
     rewriter.setInsertionPointToStart(continueBlock);
-    rewriter.create<NVVM::Barrier0Op>(loc);
-    Value *loadSrc = rewriter.create<LLVM::GEPOp>(
-        loc, type, sharedMemPtr, ArrayRef<Value *>({zero, laneId}));
-    Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
-    Value *result = createWarpReduce(loc, value, rewriter);
-
-    return result;
   }
 
   // Creates an all_reduce across the warp. Creates a preamble
@@ -196,8 +254,8 @@ private:
   //
   //     %offset = llvm.mlir.constant(i : i32) : !llvm.i32
   //     %value = nvvm.shfl.sync.bfly
-  //        %active_mask, %operand, %offset, %mask_and_clamp : !llvm.float
-  //     %operand = llvm.fadd %operand, %value : !llvm.float
+  //        %active_mask, %operand, %offset, %mask_and_clamp
+  //     %operand = llvm.fadd %operand, %value
   //
   // Each thread returns the same result.
   //
@@ -244,23 +302,46 @@ private:
     return rewriter.create<LLVM::AddressOfOp>(loc, globalOp);
   }
 
-  // Returns the index of the warp within the block.
-  //
-  //     %warp_size = llvm.mlir.constant(32 : i32) : !llvm.i32
-  //     %thread_idx = nvvm.read.ptx.sreg.tid.x  : !llvm.i32
-  //     %warp_idx = llvm.sdiv %thread_idx, %warp_size : !llvm.i32
-  //
-  Value *getWarpId(Location loc, ConversionPatternRewriter &rewriter) const {
-    auto warpSize = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
-    auto threadIdx = getLinearThreadIndex(loc, rewriter);
-    return rewriter.create<LLVM::SDivOp>(loc, int32Type, threadIdx, warpSize);
-  }
-
+  // Returns the index of the thread within the block.
   Value *getLinearThreadIndex(Location loc,
                               ConversionPatternRewriter &rewriter) const {
-    // TODO(csigg): support 2- and 3-dimensional blocks.
-    return rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
+    Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
+    Value *dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
+    Value *idX = rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
+    Value *idY = rewriter.create<NVVM::ThreadIdYOp>(loc, int32Type);
+    Value *idZ = rewriter.create<NVVM::ThreadIdZOp>(loc, int32Type);
+    Value *tmp1 = rewriter.create<LLVM::MulOp>(loc, int32Type, idZ, dimY);
+    Value *tmp2 = rewriter.create<LLVM::AddOp>(loc, int32Type, tmp1, idY);
+    Value *tmp3 = rewriter.create<LLVM::MulOp>(loc, int32Type, tmp2, dimX);
+    return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
+  }
+
+  // Returns the number of warps in the block.
+  Value *getNumWarps(Location loc, ConversionPatternRewriter &rewriter) const {
+    auto blockSize = getBlockSize(loc, rewriter);
+    auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
+    auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
+        loc, int32Type, blockSize, warpSizeMinusOne);
+    return getDivideByWarpSize(biasedBlockSize, rewriter);
+  }
+
+  // Returns the number of threads in the block.
+  Value *getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
+    Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
+    Value *dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
+    Value *dimZ = rewriter.create<NVVM::BlockDimZOp>(loc, int32Type);
+    Value *dimXY = rewriter.create<LLVM::MulOp>(loc, int32Type, dimX, dimY);
+    return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
+  }
+
+  // Returns value divided by the warp size (i.e. 32).
+  Value *getDivideByWarpSize(Value *value,
+                             ConversionPatternRewriter &rewriter) const {
+    auto loc = value->getLoc();
+    auto warpSize = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
+    return rewriter.create<LLVM::SDivOp>(loc, int32Type, value, warpSize);
   }
 
   LLVM::LLVMType int32Type;
index 1bf1597..d607870 100644 (file)
@@ -2,24 +2,30 @@
 
 // CHECK: [8.128000e+03, 8.128000e+03, {{.*}}, 8.128000e+03, 8.128000e+03]
 func @main() {
-  %arg = alloc() : memref<128xf32>
-  %dst = memref_cast %arg : memref<128xf32> to memref<?xf32>
+  %arg = alloc() : memref<16x4x2xf32>
+  %dst = memref_cast %arg : memref<16x4x2xf32> to memref<?x?x?xf32>
   %zero = constant 0 : i32
   %one = constant 1 : index
-  %size = dim %dst, 0 : memref<?xf32>
-  call @mcuMemHostRegister(%dst, %zero) : (memref<?xf32>, i32) -> ()
+  %sx = dim %dst, 0 : memref<?x?x?xf32>
+  %sy = dim %dst, 1 : memref<?x?x?xf32>
+  %sz = dim %dst, 2 : memref<?x?x?xf32>
+  call @mcuMemHostRegister(%dst, %zero) : (memref<?x?x?xf32>, i32) -> ()
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
-             threads(%tx, %ty, %tz) in (%block_x = %size, %block_y = %one, %block_z = %one)
-             args(%kernel_dst = %dst) : memref<?xf32> {
-    %idx = index_cast %tx : index to i32
-    %val = sitofp %idx : i32 to f32
+             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz)
+             args(%kernel_dst = %dst) : memref<?x?x?xf32> {
+    %t0 = muli %tz, %block_y : index
+    %t1 = addi %ty, %t0 : index
+    %t2 = muli %t1, %block_x : index
+    %idx = addi %tx, %t2 : index
+    %t3 = index_cast %idx : index to i32
+    %val = sitofp %t3 : i32 to f32
     %sum = "gpu.all_reduce"(%val) { op = "add" } : (f32) -> (f32)
-    store %sum, %kernel_dst[%tx] : memref<?xf32>
+    store %sum, %kernel_dst[%tx, %ty, %tz] : memref<?x?x?xf32>
     gpu.return
   }
-  call @mcuPrintFloat(%dst) : (memref<?xf32>) -> ()
+  call @mcuPrintFloat(%dst) : (memref<?x?x?xf32>) -> ()
   return
 }
 
-func @mcuMemHostRegister(%ptr : memref<?xf32>, %flags : i32)
-func @mcuPrintFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegister(%ptr : memref<?x?x?xf32>, %flags : i32)
+func @mcuPrintFloat(%ptr : memref<?x?x?xf32>)