From 1129931a625b19b57800c938f528b53f9ce737c1 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg@google.com>
Date: Tue, 1 Oct 2019 02:50:47 -0700
Subject: [PATCH] Change all_reduce lowering to support 2D and 3D blocks.

Perform the second reduce only with the first warp. This requires an
additional __syncthreads(), but doesn't need special handling when the
last warp is small. This simplifies support for block sizes that are not
a multiple of 32.

Supporting partial warp reduce will be done in a separate CL.

PiperOrigin-RevId: 272168917
---
 .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp  | 167 +++++++++++++++------
 mlir/test/mlir-cuda-runner/all-reduce.mlir         |  30 ++--
 2 files changed, 142 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index a36a5b3..8145ec4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -123,30 +123,46 @@ private:
   //
   // First reduce the elements within a warp. The first thread of each warp
   // writes the intermediate result to shared memory. After synchronizing the
-  // block, each warp reduces all values from shared memory.
+  // block, the first warp reduces the values from shared memory. The result
+  // is broadcasted to all threads through shared memory.
   //
-  //     %warp_reduce = ... (see createWarpReduce)
-  //     %buffer = llvm.mlir.addressof @reduce_buffer : !llvm<"[32 x float]*">
+  //     %warp_reduce = `createWarpReduce(%operand)`
+  //     %shared_mem_ptr = llvm.mlir.addressof @reduce_buffer
   //     %zero = llvm.mlir.constant(0 : i32) : !llvm.i32
   //     %lane_id = nvvm.read.ptx.sreg.laneid : !llvm.i32
-  //     %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i32
-  //     llvm.cond_br %is_first_lane, ^then, ^continue
-  //   ^then:
-  //     %warp_id = ... (see getWarpId)
-  //     %store_dst = llvm.getelementptr %buffer[%zero, %warp_id]
-  //     llvm.store %store_dst, %warp_reduce : !llvm.float
-  //     llvm.br ^continue
-  //   ^continue:
+  //     %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i1
+  //     %thread_idx = `getLinearThreadIndex()` : !llvm.i32
+  //     llvm.cond_br %is_first_lane, ^then1, ^continue1
+  //   ^then1:
+  //     %warp_id = `getDivideByWarpSize(%thread_idx)`
+  //     %store_dst = llvm.getelementptr %shared_mem_ptr[%zero, %warp_id]
+  //     llvm.store %warp_reduce, %store_dst
+  //     llvm.br ^continue1
+  //   ^continue1:
   //     nvvm.barrier0
-  //     %load_src = llvm.getelementptr %buffer[%zero, %lane_id]
-  //     %value = llvm.load %load_src : !llvm.float
-  //     %result = ... (see createWarpReduce)
+  //     %num_warps = `getNumWarps()` : !llvm.i32
+  //     %is_valid_warp = llvm.icmp "slt" %thread_idx, %num_warps
+  //     %result_ptr = llvm.getelementptr %shared_mem_ptr[%zero, %zero]
+  //     llvm.cond_br %is_valid_warp, ^then2, ^continue2
+  //   ^then2:
+  //     %load_src = llvm.getelementptr %shared_mem_ptr[%zero, %thread_idx]
+  //     %value = llvm.load %load_src
+  //     %result = `createWarpReduce(%value)`
+  //     llvm.store %result, %result_ptr
+  //     llvm.br ^continue2
+  //   ^continue2:
+  //     nvvm.barrier0
+  //     %result = llvm.load %result_ptr
+  //     return %result
+  //
   Value *createBlockReduce(Location loc, Value *operand,
                            ConversionPatternRewriter &rewriter) const {
     auto type = operand->getType().cast<LLVM::LLVMType>();
 
+    // Reduce elements within each warp to produce the intermediate results.
     Value *warpReduce = createWarpReduce(loc, operand, rewriter);
 
+    // Create shared memory array to store the warp reduction.
     auto module = warpReduce->getDefiningOp()->getParentOfType<ModuleOp>();
     assert(module && "op must belong to a module");
     Value *sharedMemPtr =
@@ -157,7 +173,59 @@ private:
     Value *laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
     Value *isFirstLane = rewriter.create<LLVM::ICmpOp>(
         loc, LLVM::ICmpPredicate::eq, laneId, zero);
+    Value *threadIdx = getLinearThreadIndex(loc, rewriter);
+
+    // Write the intermediate results to shared memory, using the first lane
+    // of each warp.
+    createPredicatedBlock(
+        loc, isFirstLane,
+        [&] {
+          Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
+          Value *storeDst = rewriter.create<LLVM::GEPOp>(
+              loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
+          rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
+        },
+        rewriter);
+
+    rewriter.create<NVVM::Barrier0Op>(loc);
+    Value *numWarps = getNumWarps(loc, rewriter);
+    Value *isValidWarp = rewriter.create<LLVM::ICmpOp>(
+        loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
+    Value *resultPtr = rewriter.create<LLVM::GEPOp>(
+        loc, type, sharedMemPtr, ArrayRef<Value *>({zero, zero}));
+
+    // Use the first numWarps threads to reduce the intermediate results from
+    // shared memory. The final result is written to shared memory again.
+    createPredicatedBlock(
+        loc, isValidWarp,
+        [&] {
+          Value *loadSrc = rewriter.create<LLVM::GEPOp>(
+              loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
+          Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
+          Value *result = createWarpReduce(loc, value, rewriter);
+          rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
+        },
+        rewriter);
+
+    rewriter.create<NVVM::Barrier0Op>(loc);
+    Value *result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
+    return result;
+  }
+
+  // Creates an if-block skeleton to perform conditional execution of the
+  // instructions generated by predicatedOpsFactory.
+  //
+  //     llvm.cond_br %condition, ^then, ^continue
+  //   ^then:
+  //     ... code created in `predicatedOpsFactory()`
+  //     llvm.br ^continue
+  //   ^continue:
+  //
+  template <typename Func>
+  void createPredicatedBlock(Location loc, Value *condition,
+                             Func &&predicatedOpsFactory,
+                             ConversionPatternRewriter &rewriter) const {
     Block *currentBlock = rewriter.getInsertionBlock();
     auto currentPoint = rewriter.getInsertionPoint();
@@ -166,25 +234,15 @@ private:
     rewriter.setInsertionPointToEnd(currentBlock);
     rewriter.create<LLVM::CondBrOp>(
-        loc, llvm::makeArrayRef(isFirstLane),
+        loc, llvm::makeArrayRef(condition),
         ArrayRef<Block *>{thenBlock, continueBlock});
 
     rewriter.setInsertionPointToEnd(thenBlock);
-    Value *warpId = getWarpId(loc, rewriter);
-    Value *storeDst = rewriter.create<LLVM::GEPOp>(
-        loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
-    rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
+    predicatedOpsFactory();
     rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>(),
                                 llvm::makeArrayRef(continueBlock));
 
     rewriter.setInsertionPointToStart(continueBlock);
-    rewriter.create<NVVM::Barrier0Op>(loc);
-    Value *loadSrc = rewriter.create<LLVM::GEPOp>(
-        loc, type, sharedMemPtr, ArrayRef<Value *>({zero, laneId}));
-    Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
-    Value *result = createWarpReduce(loc, value, rewriter);
-
-    return result;
   }
 
   // Creates an all_reduce across the warp. Creates a preamble
@@ -196,8 +254,8 @@ private:
   //
   //     %offset = llvm.mlir.constant(i : i32) : !llvm.i32
   //     %value = nvvm.shfl.sync.bfly
-  //        %active_mask, %operand, %offset, %mask_and_clamp : !llvm.float
-  //     %operand = llvm.fadd %operand, %value : !llvm.float
+  //        %active_mask, %operand, %offset, %mask_and_clamp
+  //     %operand = llvm.fadd %operand, %value
   //
   // Each thread returns the same result.
   //
@@ -244,23 +302,46 @@ private:
     return rewriter.create<LLVM::AddressOfOp>(loc, globalOp);
   }
 
-  // Returns the index of the warp within the block.
-  //
-  //     %warp_size = llvm.mlir.constant(32 : i32) : !llvm.i32
-  //     %thread_idx = nvvm.read.ptx.sreg.tid.x : !llvm.i32
-  //     %warp_idx = llvm.sdiv %thread_idx, %warp_size : !llvm.i32
-  //
-  Value *getWarpId(Location loc, ConversionPatternRewriter &rewriter) const {
-    auto warpSize = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
-    auto threadIdx = getLinearThreadIndex(loc, rewriter);
-    return rewriter.create<LLVM::SDivOp>(loc, int32Type, threadIdx, warpSize);
-  }
-
+  // Returns the index of the thread within the block.
   Value *getLinearThreadIndex(Location loc,
                               ConversionPatternRewriter &rewriter) const {
-    // TODO(csigg): support 2- and 3-dimensional blocks.
-    return rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
+    Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
+    Value *dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
+    Value *idX = rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
+    Value *idY = rewriter.create<NVVM::ThreadIdYOp>(loc, int32Type);
+    Value *idZ = rewriter.create<NVVM::ThreadIdZOp>(loc, int32Type);
+    Value *tmp1 = rewriter.create<LLVM::MulOp>(loc, int32Type, idZ, dimY);
+    Value *tmp2 = rewriter.create<LLVM::AddOp>(loc, int32Type, tmp1, idY);
+    Value *tmp3 = rewriter.create<LLVM::MulOp>(loc, int32Type, tmp2, dimX);
+    return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
+  }
+
+  // Returns the number of warps in the block.
+  Value *getNumWarps(Location loc, ConversionPatternRewriter &rewriter) const {
+    auto blockSize = getBlockSize(loc, rewriter);
+    auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
+    auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
+        loc, int32Type, blockSize, warpSizeMinusOne);
+    return getDivideByWarpSize(biasedBlockSize, rewriter);
+  }
+
+  // Returns the number of threads in the block.
+  Value *getBlockSize(Location loc,
+                      ConversionPatternRewriter &rewriter) const {
+    Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
+    Value *dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
+    Value *dimZ = rewriter.create<NVVM::BlockDimZOp>(loc, int32Type);
+    Value *dimXY = rewriter.create<LLVM::MulOp>(loc, int32Type, dimX, dimY);
+    return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
+  }
+
+  // Returns value divided by the warp size (i.e. 32).
+  Value *getDivideByWarpSize(Value *value,
+                             ConversionPatternRewriter &rewriter) const {
+    auto loc = value->getLoc();
+    auto warpSize = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
+    return rewriter.create<LLVM::SDivOp>(loc, int32Type, value, warpSize);
   }
 
   LLVM::LLVMType int32Type;
diff --git a/mlir/test/mlir-cuda-runner/all-reduce.mlir b/mlir/test/mlir-cuda-runner/all-reduce.mlir
index 1bf1597..d607870 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce.mlir
@@ -2,24 +2,30 @@
 
 // CHECK: [8.128000e+03, 8.128000e+03, {{.*}}, 8.128000e+03, 8.128000e+03]
 func @main() {
-  %arg = alloc() : memref<128xf32>
-  %dst = memref_cast %arg : memref<128xf32> to memref<?xf32>
+  %arg = alloc() : memref<16x4x2xf32>
+  %dst = memref_cast %arg : memref<16x4x2xf32> to memref<?x?x?xf32>
   %zero = constant 0 : i32
   %one = constant 1 : index
-  %size = dim %dst, 0 : memref<?xf32>
-  call @mcuMemHostRegister(%dst, %zero) : (memref<?xf32>, i32) -> ()
+  %sx = dim %dst, 0 : memref<?x?x?xf32>
+  %sy = dim %dst, 1 : memref<?x?x?xf32>
+  %sz = dim %dst, 2 : memref<?x?x?xf32>
+  call @mcuMemHostRegister(%dst, %zero) : (memref<?x?x?xf32>, i32) -> ()
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
-             threads(%tx, %ty, %tz) in (%block_x = %size, %block_y = %one, %block_z = %one)
-             args(%kernel_dst = %dst) : memref<?xf32> {
-    %idx = index_cast %tx : index to i32
-    %val = sitofp %idx : i32 to f32
+             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz)
+             args(%kernel_dst = %dst) : memref<?x?x?xf32> {
+    %t0 = muli %tz, %block_y : index
+    %t1 = addi %ty, %t0 : index
+    %t2 = muli %t1, %block_x : index
+    %idx = addi %tx, %t2 : index
+    %t3 = index_cast %idx : index to i32
+    %val = sitofp %t3 : i32 to f32
     %sum = "gpu.all_reduce"(%val) { op = "add" } : (f32) -> (f32)
-    store %sum, %kernel_dst[%tx] : memref<?xf32>
+    store %sum, %kernel_dst[%tx, %ty, %tz] : memref<?x?x?xf32>
     gpu.return
   }
-  call @mcuPrintFloat(%dst) : (memref<?xf32>) -> ()
+  call @mcuPrintFloat(%dst) : (memref<?x?x?xf32>) -> ()
   return
 }
-func @mcuMemHostRegister(%ptr : memref<?xf32>, %flags : i32)
-func @mcuPrintFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegister(%ptr : memref<?x?x?xf32>, %flags : i32)
+func @mcuPrintFloat(%ptr : memref<?x?x?xf32>)
-- 
2.7.4
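Note on the updated test: the kernel linearizes the 3D thread index the same
way getLinearThreadIndex does, i.e. %idx = (%tz * %block_y + %ty) * %block_x
+ %tx. With the 16x4x2 block this yields the values 0 through 127, so every
element of the all_reduce result is 0 + 1 + ... + 127 = 127 * 128 / 2 = 8128,
which is the 8.128000e+03 expected by the CHECK line.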
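For readers who want to see the generated structure in one place, below is a
hand-written CUDA sketch of the two-stage block reduction that this lowering
emits. It is illustrative only: the helper names (warpReduceAdd,
blockReduceAdd, allReduceKernel) are invented for this sketch, the pass emits
NVVM/LLVM dialect ops rather than CUDA C++, and, matching the limitation
noted in the commit message, it assumes the block size is a multiple of the
warp size.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

// Butterfly (xor) reduction: every lane of a full warp ends up with the sum.
__device__ float warpReduceAdd(float value) {
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
    value += __shfl_xor_sync(0xffffffffu, value, offset);
  return value;
}

// Two-stage block reduction: per-warp reductions first, then the first warp
// combines the partial results and broadcasts them via shared memory.
__device__ float blockReduceAdd(float value) {
  __shared__ float buffer[kWarpSize];  // one slot per warp; 32 covers 1024 threads

  int linearTid =
      (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
  int blockSize = blockDim.x * blockDim.y * blockDim.z;
  int numWarps = (blockSize + kWarpSize - 1) / kWarpSize;

  // Stage 1: reduce within each warp; lane 0 publishes the partial result.
  value = warpReduceAdd(value);
  if (linearTid % kWarpSize == 0)
    buffer[linearTid / kWarpSize] = value;
  __syncthreads();

  // Stage 2: the first (full) warp reduces the partial results. Slots past
  // numWarps read as the additive identity so the full-warp shuffle stays
  // well-defined.
  if (linearTid < kWarpSize) {
    float partial = linearTid < numWarps ? buffer[linearTid] : 0.0f;
    float total = warpReduceAdd(partial);
    if (linearTid == 0)
      buffer[0] = total;
  }
  __syncthreads();

  // Broadcast the result to all threads through shared memory.
  return buffer[0];
}

__global__ void allReduceKernel(float *dst) {
  int linearTid =
      (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
  dst[linearTid] = blockReduceAdd(static_cast<float>(linearTid));
}

int main() {
  float *dst = nullptr;
  cudaMallocManaged(&dst, 128 * sizeof(float));
  allReduceKernel<<<1, dim3(16, 4, 2)>>>(dst);  // same block shape as the test
  cudaDeviceSynchronize();
  std::printf("%g %g\n", dst[0], dst[127]);  // expect 8128 8128
  cudaFree(dst);
  return 0;
}

Unlike this sketch, the lowered code predicates the second reduction on
%thread_idx < %num_warps and relies on the active mask computed inside
createWarpReduce; padding the first warp's unused slots with the identity is
simply the easiest way to keep a full-warp shuffle valid in plain CUDA.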