From 116dac00baa6870aec2a2b469b2d6f95c2fbb316 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 26 Sep 2019 00:17:13 -0700 Subject: [PATCH] Add AllReduceOp to GPU dialect with lowering to NVVM. The reduction operation is currently fixed to "add", and the scope is fixed to "workgroup". The implementation is currently limited to sizes that are multiple 32 (warp size) and no larger than 1024. PiperOrigin-RevId: 271290265 --- mlir/include/mlir/Dialect/GPU/GPUOps.td | 19 +++ .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 171 ++++++++++++++++++++- mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 7 + mlir/test/Dialect/GPU/ops.mlir | 3 + mlir/test/mlir-cuda-runner/all-reduce.mlir | 25 +++ 5 files changed, 223 insertions(+), 2 deletions(-) create mode 100644 mlir/test/mlir-cuda-runner/all-reduce.mlir diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td index 75e0f6b..aa8046e 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -59,4 +59,23 @@ def gpu_Return : GPU_Op<"return", [Terminator]>, Arguments<(ins)>, let printer = [{ p << getOperationName(); }]; } +def gpu_AllReduce : GPU_Op<"all_reduce", [SameOperandsAndResultType]>, + Arguments<(ins AnyType)>, Results<(outs AnyType)> { + let summary = "Reduce values among workgroup."; + let description = [{ + The "all_reduce" op reduces the value of every invocation across a local + workgroup. + + For example, + ``` + %1 = gpu.all_reduce %0 : f32 + ``` + computes the sum of each invocation's %0 value. The value of %1 is always + equal for all invocations of a local workgroup. + + Either none or all invocations of a local workgroup need to execute this op + in convergence. + }]; +} + #endif // GPU_OPS diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 0028bad..f5f5f99 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -103,6 +103,173 @@ public: } }; +// Converts all_reduce op to LLVM/NVVM ops. +struct GPUAllReduceOpLowering : public LLVMOpLowering { + explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_) + : LLVMOpLowering(gpu::AllReduce::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), + int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Value *result = createBlockReduce(op->getLoc(), operands.front(), rewriter); + rewriter.replaceOp(op, {result}); + return matchSuccess(); + } + +private: + // Creates an all_reduce across the local workgroup. + // + // First reduce the elements within a subgroup (i.e. warp). The first + // invocation of each subgroup writes the intermediate result to shared + // memory. After synchronizing the local workgroup, each subgroup reduces all + // values from shared memory. + // + // %warp_reduce = ... (see createWarpReduce) + // %buffer = llvm.mlir.addressof @reduce_buffer : !llvm<"[32 x float]*"> + // %zero = llvm.mlir.constant(0 : i32) : !llvm.i32 + // %lane_id = nvvm.read.ptx.sreg.laneid : !llvm.i32 + // %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i32 + // llvm.cond_br %is_first_lane, ^then, ^continue + // ^then: + // %warp_id = ... (see getWarpId) + // %store_dst = llvm.getelementptr %buffer[%zero, %warp_id] + // llvm.store %store_dst, %warp_reduce : !llvm.float + // llvm.br ^continue + // ^continue: + // nvvm.barrier0 + // %load_src = llvm.getelementptr %buffer[%zero, %lane_id] + // %value = llvm.load %load_src : !llvm.float + // %result = ... (see createWarpReduce) + Value *createBlockReduce(Location loc, Value *operand, + ConversionPatternRewriter &rewriter) const { + auto type = operand->getType().cast(); + + Value *warpReduce = createWarpReduce(loc, operand, rewriter); + + auto module = warpReduce->getDefiningOp()->getParentOfType(); + assert(module && "op must belong to a module"); + Value *sharedMemPtr = + createSharedMemoryArray(loc, module, type, kWarpSize, rewriter); + + Value *zero = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(0u)); + Value *laneId = rewriter.create(loc, int32Type); + Value *isFirstLane = rewriter.create( + loc, LLVM::ICmpPredicate::eq, laneId, zero); + + Block *currentBlock = rewriter.getInsertionBlock(); + auto currentPoint = rewriter.getInsertionPoint(); + + Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint); + Block *continueBlock = rewriter.splitBlock(thenBlock, currentPoint); + + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create( + loc, llvm::makeArrayRef(isFirstLane), + ArrayRef{thenBlock, continueBlock}); + + rewriter.setInsertionPointToEnd(thenBlock); + Value *warpId = getWarpId(loc, rewriter); + Value *storeDst = rewriter.create( + loc, type, sharedMemPtr, ArrayRef({zero, warpId})); + rewriter.create(loc, warpReduce, storeDst); + rewriter.create(loc, ArrayRef(), + llvm::makeArrayRef(continueBlock)); + + rewriter.setInsertionPointToStart(continueBlock); + rewriter.create(loc); + Value *loadSrc = rewriter.create( + loc, type, sharedMemPtr, ArrayRef({zero, laneId})); + Value *value = rewriter.create(loc, type, loadSrc); + Value *result = createWarpReduce(loc, value, rewriter); + + return result; + } + + // Creates an all_reduce across the subgroup. Creates a preamble + // + // %active_mask = llvm.mlir.constant(-1 : i32) : !llvm.i32 + // %mask_and_clamp = llvm.mlir.constant(31 : i32) : !llvm.i32 + // + // plus the accumulation for i = 1, 2, 4, 8, 16: + // + // %offset = llvm.mlir.constant(i : i32) : !llvm.i32 + // %value = nvvm.shfl.sync.bfly + // %active_mask, %operand, %offset, %mask_and_clamp : !llvm.float + // %operand = llvm.fadd %operand, %value : !llvm.float + // + // Each invocation returns the same result. + // + // Note: this currently only supports reducing exactly 32 values. + Value *createWarpReduce(Location loc, Value *operand, + ConversionPatternRewriter &rewriter) const { + // TODO(csigg): Generalize to partial warps and other types of accumulation. + static_assert(kWarpSize == 32, "Only warp size of 32 is supported."); + auto activeMask = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(~0u)); + auto maskAndClamp = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1)); + + auto resultType = operand->getType(); + for (int i = 1; i < kWarpSize; i <<= 1) { + auto offset = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(i)); + auto value = rewriter.create( + loc, resultType, activeMask, operand, offset, maskAndClamp); + operand = rewriter.create(loc, resultType, operand, value); + } + return operand; + } + + // Creates a global array stored in shared memory. + // + // llvm.mlir.global @reduce_buffer() + // {addr_space = 3 : i32} : !llvm<"[32 x float]"> + // + Value *createSharedMemoryArray(Location loc, ModuleOp module, + LLVM::LLVMType elementType, int numElements, + ConversionPatternRewriter &rewriter) const { + OpBuilder builder(module.getBodyRegion()); + + auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements); + StringRef name = "reduce_buffer"; + auto addrSpace = + builder.getNamedAttr("addr_space", builder.getI32IntegerAttr(3)); + auto globalOp = builder.create( + loc, arrayType.cast(), + /*isConstant=*/false, name, /*value=*/Attribute(), + llvm::makeArrayRef(addrSpace)); + + return rewriter.create(loc, globalOp); + } + + // Returns the index of the subgroup within the local workgroup. + // + // %warp_size = llvm.mlir.constant(32 : i32) : !llvm.i32 + // %thread_idx = nvvm.read.ptx.sreg.tid.x : !llvm.i32 + // %warp_idx = llvm.sdiv %thread_idx, %warp_size : !llvm.i32 + // + Value *getWarpId(Location loc, ConversionPatternRewriter &rewriter) const { + auto warpSize = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize)); + auto threadIdx = getLinearThreadIndex(loc, rewriter); + return rewriter.create(loc, int32Type, threadIdx, warpSize); + } + + Value *getLinearThreadIndex(Location loc, + ConversionPatternRewriter &rewriter) const { + // TODO(csigg): support 2- and 3-dimensional blocks. + return rewriter.create(loc, int32Type); + } + + LLVM::LLVMType int32Type; + + // TODO(csigg): Support other warp sizes. + static constexpr int kWarpSize = 32; +}; + // A pass that replaces all occurences of GPU device operations with their // corresponding NVVM equivalent. // @@ -126,8 +293,8 @@ public: GPUIndexIntrinsicOpLowering, GPUIndexIntrinsicOpLowering>( - converter); + NVVM::GridDimYOp, NVVM::GridDimZOp>, + GPUAllReduceOpLowering>(converter); ConversionTarget target(getContext()); target.addLegalDialect(); diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index 0263737..6168be3 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -32,6 +32,13 @@ module attributes {gpu.kernel_module} { // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32 %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index) + %one = constant 1.0 : f32 + // TODO(csigg): Check full IR expansion once lowering has settled. + // CHECK: nvvm.shfl.sync.bfly + // CHECK: nvvm.barrier0 + // CHECK: nvvm.shfl.sync.bfly + %result = "gpu.all_reduce"(%one) {scope = "workgroup", kernel = "add"} : (f32) -> (f32) + std.return } } diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index b78da95..7c8f682 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -76,6 +76,9 @@ func @kernel_1(%arg0 : f32, %arg1 : memref) %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index) %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index) + %one = constant 1.0 : f32 + %sum = "gpu.all_reduce"(%one) : (f32) -> (f32) + "some_op"(%bIdX, %tIdX) : (index, index) -> () %42 = load %arg1[%bIdX] : memref return diff --git a/mlir/test/mlir-cuda-runner/all-reduce.mlir b/mlir/test/mlir-cuda-runner/all-reduce.mlir new file mode 100644 index 0000000..1bf1597 --- /dev/null +++ b/mlir/test/mlir-cuda-runner/all-reduce.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext --entry-point-result=void | FileCheck %s + +// CHECK: [8.128000e+03, 8.128000e+03, {{.*}}, 8.128000e+03, 8.128000e+03] +func @main() { + %arg = alloc() : memref<128xf32> + %dst = memref_cast %arg : memref<128xf32> to memref + %zero = constant 0 : i32 + %one = constant 1 : index + %size = dim %dst, 0 : memref + call @mcuMemHostRegister(%dst, %zero) : (memref, i32) -> () + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one) + threads(%tx, %ty, %tz) in (%block_x = %size, %block_y = %one, %block_z = %one) + args(%kernel_dst = %dst) : memref { + %idx = index_cast %tx : index to i32 + %val = sitofp %idx : i32 to f32 + %sum = "gpu.all_reduce"(%val) { op = "add" } : (f32) -> (f32) + store %sum, %kernel_dst[%tx] : memref + gpu.return + } + call @mcuPrintFloat(%dst) : (memref) -> () + return +} + +func @mcuMemHostRegister(%ptr : memref, %flags : i32) +func @mcuPrintFloat(%ptr : memref) -- 2.7.4