From 0660f3c5a0a083cef78f39b3105a9a8e27cf5095 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Sat, 9 Jul 2022 18:36:39 +0000
Subject: [PATCH] [mlir][vector] Relax reduction distribution pattern

Support distributing reductions whose vector size is a multiple of the
warp size.

Differential Revision: https://reviews.llvm.org/D129387
---
 .../Dialect/Vector/Transforms/VectorDistribute.cpp | 20 ++++++------
 .../Dialect/Vector/vector-warp-distribute.mlir     | 36 +++++++++++++++++++++-
 2 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index bf6e222..2b96358 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -798,7 +798,7 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return rewriter.notifyMatchFailure(
           warpOp, "Only rank 1 reductions can be distributed.");
     // Only warp_size-sized vectors supported.
-    if (static_cast<int64_t>(vectorType.getShape()[0]) != warpOp.getWarpSize())
+    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
       return rewriter.notifyMatchFailure(
           warpOp, "Reduction vector dimension must match was size.");
     // Only f32 and i32 element types are supported.
@@ -808,24 +808,26 @@
           warpOp,
           "Reduction distribution currently only supports 32bits types.");
 
-    Location yieldLoc = yieldOperand->getOwner()->getLoc();
-
+    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
     unsigned operandIndex = yieldOperand->getOperandNumber();
     SmallVector<Value> yieldValues = {reductionOp.getVector()};
-    SmallVector<Type> retTypes = {VectorType::get({1}, reductionOp.getType())};
+    SmallVector<Type> retTypes = {
+        VectorType::get({numElements}, reductionOp.getType())};
     unsigned numResults = warpOp.getNumResults();
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, yieldValues, retTypes);
     rewriter.setInsertionPointAfter(newWarpOp);
 
-    // Every lane has one scalar value. These should be reduced.
     Value laneValVec = newWarpOp.getResult(numResults);
-    Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0);
-    laneVal =
-        distributedReductionFn(reductionOp.getLoc(), rewriter, laneVal,
+    // First reduce on a single thread.
+    Value perLaneReduction = rewriter.create<vector::ReductionOp>(
+        reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
+    // Then distribute across threads.
+    Value fullReduce =
+        distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
                                reductionOp.getKind(), newWarpOp.getWarpSize());
-    newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal);
+    newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
     return success();
   }
 
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 718a7bf..82f6299 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -509,5 +509,39 @@ func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref
     %5 = vector.broadcast %4 : f32 to vector<f32>
     vector.transfer_write %5, %m1[] : vector<f32>, memref<f32>
   }
-  return
+  return
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_reduction_large(
+//  CHECK-PROP-SAME:   %[[laneid:.*]]: index)
+//   CHECK-PROP-DAG:   %[[c1:.*]] = arith.constant 1 : i32
+//   CHECK-PROP-DAG:   %[[c2:.*]] = arith.constant 2 : i32
+//   CHECK-PROP-DAG:   %[[c4:.*]] = arith.constant 4 : i32
+//   CHECK-PROP-DAG:   %[[c8:.*]] = arith.constant 8 : i32
+//   CHECK-PROP-DAG:   %[[c16:.*]] = arith.constant 16 : i32
+//   CHECK-PROP-DAG:   %[[c32:.*]] = arith.constant 32 : i32
+//       CHECK-PROP:   %[[warp_op:.*]] = vector.warp_execute_on_lane_0(%[[laneid]])[32] -> (vector<2xf32>) {
+//       CHECK-PROP:     vector.yield %{{.*}} : vector<64xf32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[a:.*]] = vector.reduction <add>, %[[warp_op]] : vector<2xf32> into f32
+//       CHECK-PROP:   %[[r0:.*]], %{{.*}} = gpu.shuffle xor %[[a]], %[[c1]], %[[c32]]
+//       CHECK-PROP:   %[[a0:.*]] = arith.addf %[[a]], %[[r0]]
+//       CHECK-PROP:   %[[r1:.*]], %{{.*}} = gpu.shuffle xor %[[a0]], %[[c2]], %[[c32]]
+//       CHECK-PROP:   %[[a1:.*]] = arith.addf %[[a0]], %[[r1]]
+//       CHECK-PROP:   %[[r2:.*]], %{{.*}} = gpu.shuffle xor %[[a1]], %[[c4]], %[[c32]]
+//       CHECK-PROP:   %[[a2:.*]] = arith.addf %[[a1]], %[[r2]]
+//       CHECK-PROP:   %[[r3:.*]], %{{.*}} = gpu.shuffle xor %[[a2]], %[[c8]], %[[c32]]
+//       CHECK-PROP:   %[[a3:.*]] = arith.addf %[[a2]], %[[r3]]
+//       CHECK-PROP:   %[[r4:.*]], %{{.*}} = gpu.shuffle xor %[[a3]], %[[c16]], %[[c32]]
+//       CHECK-PROP:   %[[a4:.*]] = arith.addf %[[a3]], %[[r4]]
+//       CHECK-PROP:   return %[[a4]] : f32
+func.func @vector_reduction_large(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.reduction <add>, %0 : vector<64xf32> into f32
+    vector.yield %1 : f32
+  }
+  return %r : f32
 }
-- 
2.7.4
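
Illustration (not part of the patch): the computation that the relaxed pattern produces for the new @vector_reduction_large test can be emulated on the host in plain C++. The sketch below assumes an add reduction of 64 f32 values over a 32-lane warp and a contiguous element-to-lane assignment; the exact assignment is immaterial for an additive reduction, and the xor-based butterfly loop stands in for the gpu.shuffle xor / arith.addf chain checked by the test.

#include <array>
#include <cstdio>

int main() {
  constexpr int kWarpSize = 32;
  constexpr int kVectorSize = 64;
  constexpr int kPerLane = kVectorSize / kWarpSize; // numElements in the pattern

  std::array<float, kVectorSize> input;
  for (int i = 0; i < kVectorSize; ++i)
    input[i] = static_cast<float>(i);

  // Step 1: each lane reduces its private vector<2xf32> slice
  // (the per-lane vector.reduction emitted inside the warp op).
  std::array<float, kWarpSize> lane;
  for (int l = 0; l < kWarpSize; ++l)
    lane[l] = input[l * kPerLane] + input[l * kPerLane + 1];

  // Step 2: butterfly reduction across the 32 lanes with xor offsets
  // 1, 2, 4, 8, 16, emulating gpu.shuffle xor followed by arith.addf.
  for (int offset = 1; offset < kWarpSize; offset <<= 1) {
    std::array<float, kWarpSize> received;
    for (int l = 0; l < kWarpSize; ++l)
      received[l] = lane[l ^ offset]; // value shuffled in from lane (l ^ offset)
    for (int l = 0; l < kWarpSize; ++l)
      lane[l] += received[l];
  }

  // Every lane now holds the full reduction: 0 + 1 + ... + 63 = 2016.
  std::printf("lane 0 result: %f\n", lane[0]);
  return 0;
}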